# Importing Dependencies
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import seaborn as sns
sns.set_style('whitegrid')
import torch
import torch.nn as nn
import torch.distributions as dist
import torch.nn.functional as Fun
import torch.optim as optim
Implement Logistic Regression using the Pyro library, referring to [1] for guidance.
Show both the mean prediction as well as standard deviation in the predictions over the 2d grid. Use NUTS MCMC sampling to sample the posterior. Take 1000 samples for posterior distribution and use 500 samples as burn/warm up. Use the below given dataset.
# Build the noisy two-moons toy dataset, split it, and visualise it.
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split

X, y = make_moons(n_samples=100, noise=0.3, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=41)

# Torch tensors are needed for the Pyro model below.
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32)

plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Spectral)
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.title('Dataset')
plt.show()
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS
# Defining the logistic regression pyro model
def logistic_regression(X, y=None):
    """Bayesian logistic regression: p(y=1) = sigmoid(X @ w + b).

    Standard-normal priors on the weight vector (one per feature) and bias;
    Bernoulli likelihood over the observed labels.
    """
    w = pyro.sample("weight", dist.Normal(0, 1).expand([X.shape[1]]))
    b = pyro.sample("bias", dist.Normal(0, 1))
    scores = torch.matmul(X, w) + b
    # One conditionally independent likelihood term per data point.
    with pyro.plate("data", X.shape[0]):
        obs = pyro.sample("obs", dist.Bernoulli(logits=scores), obs=y)
# Running the model on the data
# NUTS MCMC as required: 1000 posterior draws after 500 warmup/burn-in steps.
model = logistic_regression; num_samples = 1000; burn_in = 500
posterior = MCMC(NUTS(model), num_samples=num_samples, warmup_steps=burn_in)
posterior.run(X_train, y_train)
# Posterior draws per site: weight is (1000, 2), bias is (1000,).
weights_samples = posterior.get_samples()["weight"]
bias_samples = posterior.get_samples()["bias"]
weights_samples.shape, bias_samples.shape
Sample: 100%|██████████| 1500/1500 [00:08, 186.67it/s, step size=7.36e-01, acc. prob=0.886]
(torch.Size([1000, 2]), torch.Size([1000]))
# 100x100 grid spanning the data range (padded by 1 on every side),
# flattened to a (10000, 2) matrix of query points.
pad = 1
grid_x = np.linspace(X[:, 0].min() - pad, X[:, 0].max() + pad, 100)
grid_y = np.linspace(X[:, 1].min() - pad, X[:, 1].max() + pad, 100)
grid_xx, grid_yy = np.meshgrid(grid_x, grid_y)
grid = np.column_stack([grid_xx.ravel(), grid_yy.ravel()])
grid.shape
(10000, 2)
# Posterior-predictive probabilities on the grid: one column per MCMC draw,
# then reduced to per-point mean and standard deviation.
grid_t = torch.tensor(grid, dtype=torch.float32)
logits_samples = grid_t @ weights_samples.t() + bias_samples
prob_samples = torch.sigmoid(logits_samples)
mean_predictions = prob_samples.mean(1)
std_predictions = prob_samples.std(1)
prob_samples.shape, mean_predictions.shape, std_predictions.shape
(torch.Size([10000, 1000]), torch.Size([10000]), torch.Size([10000]))
# Mean prediction map, with train points (left) and test points (right).
mean_predictions = mean_predictions.reshape(grid_xx.shape)
fig, ax = plt.subplots(1, 2, figsize=(12, 5))
for axis in ax:
    cs = axis.contourf(grid_xx, grid_yy, mean_predictions, levels=50, cmap="RdBu", alpha=0.6)
    plt.colorbar(cs, ax=axis, label='Mean Prediction')
ax[0].scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap="RdBu_r", edgecolor='k', label='Training Data')
ax[1].scatter(X_test[:, 0], X_test[:, 1], c=y_test, cmap="RdBu_r", edgecolor='k', label='Test Data')
ax[0].legend(); ax[1].legend()
fig.suptitle('Logistic Regression Mean Prediction')
# Predictive-uncertainty map (std across draws), same train/test layout.
std_predictions = std_predictions.reshape(grid_xx.shape)
fig, ax = plt.subplots(1, 2, figsize=(12, 5))
for axis in ax:
    cs = axis.contourf(grid_xx, grid_yy, std_predictions, levels=50, cmap="viridis", alpha=0.6)
    plt.colorbar(cs, ax=axis, label='Std Prediction')
ax[0].scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap="RdBu_r", edgecolor='k', label='Training Data')
ax[1].scatter(X_test[:, 0], X_test[:, 1], c=y_test, cmap="RdBu_r", edgecolor='k', label='Test Data')
ax[0].legend(); ax[1].legend()
fig.suptitle('Logistic Regression Std Prediction')
Consider the FVC dataset example discussed in the class.
We had only used the train dataset. Now, we want to find out the performance of various models on the test dataset.
Use the given dataset and deduce which model works best in terms of error (MAE) and coverage. The base model is Linear Regression by Sklearn (from sklearn.linear_model import LinearRegression). Plot the trace diagrams and posterior distribution.
Also plot the predictive posterior distribution with 90% confidence interval.
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
%matplotlib inline
# Retina display
%config InlineBackend.figure_format = 'retina'
from jax import random
import warnings
warnings.filterwarnings('ignore')
plt.rcParams['figure.constrained_layout.use'] = True
import seaborn as sns
sns.set_context("notebook")
import numpyro
import numpyro.distributions as dist
import os
import requests
# Download the OSIC pulmonary fibrosis CSV once and cache it locally.
URL = (
    "https://gist.githubusercontent.com/ucals/"
    "2cf9d101992cb1b78c2cdd6e3bac6a4b/raw/"
    "43034c39052dcf97d4b894d2ec1bc3f90f3623d9/"
    "osic_pulmonary_fibrosis.csv"
)
if not os.path.exists("osic_pulmonary_fibrosis.csv"):
    response = requests.get(URL)
    with open("osic_pulmonary_fibrosis.csv", "wb") as f:
        f.write(response.content)
train = pd.read_csv("osic_pulmonary_fibrosis.csv")
train.head()
| Patient | Weeks | FVC | Percent | Age | Sex | SmokingStatus | |
|---|---|---|---|---|---|---|---|
| 0 | ID00007637202177411956430 | -4 | 2315 | 58.253649 | 79 | Male | Ex-smoker |
| 1 | ID00007637202177411956430 | 5 | 2214 | 55.712129 | 79 | Male | Ex-smoker |
| 2 | ID00007637202177411956430 | 7 | 2061 | 51.862104 | 79 | Male | Ex-smoker |
| 3 | ID00007637202177411956430 | 9 | 2144 | 53.950679 | 79 | Male | Ex-smoker |
| 4 | ID00007637202177411956430 | 11 | 2069 | 52.063412 | 79 | Male | Ex-smoker |
train.describe()
| Weeks | FVC | Percent | Age | |
|---|---|---|---|---|
| count | 1549.000000 | 1549.000000 | 1549.000000 | 1549.000000 |
| mean | 31.861846 | 2690.479019 | 77.672654 | 67.188509 |
| std | 23.247550 | 832.770959 | 19.823261 | 7.057395 |
| min | -5.000000 | 827.000000 | 28.877577 | 49.000000 |
| 25% | 12.000000 | 2109.000000 | 62.832700 | 63.000000 |
| 50% | 28.000000 | 2641.000000 | 75.676937 | 68.000000 |
| 75% | 47.000000 | 3171.000000 | 88.621065 | 72.000000 |
| max | 133.000000 | 6399.000000 | 153.145378 | 88.000000 |
# Encode patient IDs as integers, keep Weeks + patient_code as features,
# FVC as the regression target, and split 80/20.
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

patient_encoder = LabelEncoder()
train["patient_code"] = patient_encoder.fit_transform(train["Patient"].values)
X = train.drop(columns=["Patient", "FVC", "Percent", "Age", "Sex", "SmokingStatus"])
y = train["FVC"]
x_train, x_test, y_train, y_test = train_test_split(X, y, train_size=0.8, random_state=0)
x_train
| Weeks | patient_code | |
|---|---|---|
| 295 | 63 | 33 |
| 516 | 63 | 59 |
| 655 | 46 | 74 |
| 838 | 48 | 94 |
| 452 | 61 | 51 |
| ... | ... | ... |
| 763 | 27 | 86 |
| 835 | 18 | 94 |
| 1216 | 38 | 138 |
| 559 | 73 | 63 |
| 684 | 18 | 77 |
1239 rows × 2 columns
len(x_train), len(x_test), len(y_train), len(y_test)
(1239, 310, 1239, 310)
# Separate the patient codes (needed by the hierarchical models), keep Weeks
# as the single regressor, and convert everything to plain numpy arrays.
sample_patient_code_train = x_train["patient_code"].values
sample_patient_code_test = x_test["patient_code"].values
x_train, x_test = np.array(x_train["Weeks"]), np.array(x_test["Weeks"])
y_train, y_test = np.array(y_train), np.array(y_test)
x_train, y_train
### Linear regression from scikit-learn
from sklearn.linear_model import LinearRegression
# Baseline model: ordinary least squares with Weeks as the only feature.
lr = LinearRegression()
lr.fit(x_train.reshape(-1, 1), y_train)
LinearRegression()In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
LinearRegression()
# Train-set residual std, used as a fixed predictive sigma for the baseline.
lr_sigma = np.std(y_train - lr.predict(x_train.reshape(-1, 1)))
lr.coef_, lr.intercept_, lr_sigma
(array([-1.68051699]), 2747.966054074546, 843.8124655154415)
# Fitted line with a ~95% band (+/- 1.96 * residual std) over all weeks.
all_weeks = np.arange(-12, 134, 1)
line = lr.predict(all_weeks.reshape(-1, 1))
plt.scatter(x_train, y_train, alpha=0.3, label="Train data")
plt.scatter(x_test, y_test, color="red", alpha=0.3, label="Test data")
plt.plot(all_weeks, line, color="black", lw=2, label="Linear Regression")
plt.fill_between(all_weeks, line - 1.96*lr_sigma, line + 1.96*lr_sigma, alpha=0.2, label="Std fill")
plt.title("Linear Regression fit")
plt.legend()
plt.xlabel("Weeks")
plt.ylabel("FVC")
Text(0, 0.5, 'FVC')
# Baseline MAE on both splits, collected into a shared dict for comparison.
from sklearn.metrics import mean_absolute_error
maes = {}
for split, (xs, ys) in {"train": (x_train, y_train), "test": (x_test, y_test)}.items():
    maes[f"LinearRegression on {split}"] = mean_absolute_error(ys, lr.predict(xs.reshape(-1, 1)))
maes
{'LinearRegression on train': 662.1236659544445,
'LinearRegression on test': 626.3184730275215}
# Finding the 95% coverage on train set
def coverage(y_true, y_pred, sigma, z=1.96):
    """Empirical coverage: fraction of y_true inside y_pred +/- z * sigma.

    Parameters
    ----------
    y_true : array-like of ground-truth values
    y_pred : array-like of point predictions, same shape as y_true
    sigma : scalar or array-like predictive standard deviation
    z : half-width in standard deviations; default 1.96 keeps the original
        95%-interval behaviour (use 1.645 for a 90% interval)

    Returns
    -------
    Fraction in [0, 1] of points falling inside the interval.
    """
    lower = y_pred - z * sigma
    upper = y_pred + z * sigma
    return np.mean((y_true >= lower) & (y_true <= upper))
coverages = {}
# Empirical 95% coverage of the baseline's +/- 1.96*sigma band on each split.
print("Train Coverage: ", coverage(y_train, lr.predict(x_train.reshape(-1, 1)), lr_sigma))
coverages["LinearRegression on test"] = coverage(y_test, lr.predict(x_test.reshape(-1, 1)), lr_sigma)
coverages["LinearRegression on train"] = coverage(y_train, lr.predict(x_train.reshape(-1, 1)), lr_sigma)
coverages
Train Coverage: 0.9548022598870056
{'LinearRegression on test': 0.9709677419354839,
'LinearRegression on train': 0.9548022598870056}
$\alpha \sim \text{Normal}(0, 500)$
$\beta \sim \text{Normal}(0, 500)$
$\sigma \sim \text{HalfNormal}(50)$
for i in range(N_Weeks):
$FVC_i \sim \text{Normal}(\alpha + \beta \cdot Week_i, \sigma)$
def pooled_model(sample_weeks, sample_fvc=None):
    """Fully pooled Bayesian linear regression: FVC ~ Normal(α + β·week, σ)."""
    α = numpyro.sample("α", dist.Normal(0., 500.))
    β = numpyro.sample("β", dist.Normal(0., 500.))
    σ = numpyro.sample("σ", dist.HalfNormal(50.))
    mean_fvc = α + β * sample_weeks
    # One conditionally independent observation per record.
    with numpyro.plate("samples", len(sample_weeks)):
        fvc = numpyro.sample("fvc", dist.Normal(mean_fvc, σ), obs=sample_fvc)
    return fvc
sample_weeks = train["Weeks"].values
sample_fvc = train["FVC"].values
from numpyro.infer import MCMC, NUTS, Predictive
# NUTS on the pooled model: 4000 posterior draws after 2000 warmup steps.
# NOTE(review): the fit uses the train split (x_train/y_train); the full-data
# sample_weeks/sample_fvc arrays above are only used for plotting later.
nuts_kernel = NUTS(pooled_model)
mcmc = MCMC(nuts_kernel, num_samples=4000, num_warmup=2000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, sample_weeks=x_train, sample_fvc=y_train)
posterior_samples = mcmc.get_samples()
sample: 100%|██████████| 6000/6000 [00:02<00:00, 2206.33it/s, 7 steps of size 4.57e-01. acc. prob=0.91]
import arviz as az
# Trace diagrams and posterior summaries for the pooled model.
idata = az.from_numpyro(mcmc)
az.plot_trace(idata, compact=True);
# Summary statistics
# (single chain, so r_hat is NaN — hence the arviz shape warning below)
az.summary(idata, round_to=2)
arviz - WARNING - Shape validation failed: input_shape: (1, 4000), minimum_shape: (chains=2, draws=4)
| mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat | |
|---|---|---|---|---|---|---|---|---|---|
| α | 2731.78 | 37.24 | 2658.75 | 2798.00 | 0.90 | 0.64 | 1713.98 | 1879.14 | NaN |
| β | -1.36 | 0.93 | -3.01 | 0.44 | 0.02 | 0.02 | 1649.09 | 1842.16 | NaN |
| σ | 773.41 | 12.98 | 750.62 | 799.75 | 0.25 | 0.18 | 2697.11 | 2607.93 | NaN |
Making predictions
# Posterior-predictive mean over all weeks, compared with the sklearn line.
predictive = Predictive(pooled_model, mcmc.get_samples())
predictions = predictive(rng_key, all_weeks, None)
print(predictions["fvc"].shape)
pd.DataFrame(predictions["fvc"]).mean().plot(label="Pooled Model")
plt.plot(all_weeks, lr.predict(all_weeks.reshape(-1, 1)), color="black", lw=2, label="Linear Regression")
plt.title("Pooled Model Predictions")
plt.xlabel("Weeks")
plt.ylabel("FVC")
plt.legend()
(4000, 146)
<matplotlib.legend.Legend at 0x153679e50>
# Posterior-predictive distribution with a ~95% band over the data.
predictive = Predictive(pooled_model, mcmc.get_samples())
predictions = predictive(rng_key, all_weeks, None)
fvc_draws = predictions["fvc"]
mu, sigma = fvc_draws.mean(axis=0), fvc_draws.std(axis=0)
plt.plot(all_weeks, mu)
plt.fill_between(all_weeks, mu - 1.96*sigma, mu + 1.96*sigma, alpha=0.2, label="Std fill")
plt.scatter(sample_weeks, sample_fvc, alpha=0.2, label="Train data")
plt.scatter(x_test, y_test, color="red", alpha=0.3, label="Test data")
plt.xlabel("Weeks")
plt.title("Predictive distribution of Pooled Model")
plt.ylabel("FVC")
plt.legend()
<matplotlib.legend.Legend at 0x1546d2290>
### Computing Mean Absolute Error and Coverage at 95% confidence interval
# Posterior-predictive mean/std on each split, then MAE of the mean.
preds_pooled = predictive(rng_key, x_train, None)['fvc']
predictions_train_pooled = preds_pooled.mean(axis=0)
std_train_pooled = preds_pooled.std(axis=0)
pred_test_pooled = predictive(rng_key, x_test, None)['fvc']
predictions_test_pooled = pred_test_pooled.mean(axis=0)
std_test_pooled = pred_test_pooled.std(axis=0)
maes["PooledModel on train"] = mean_absolute_error(y_train, predictions_train_pooled)
maes["PooledModel on test"] = mean_absolute_error(y_test, predictions_test_pooled)
maes
{'LinearRegression on train': 662.1236659544445,
'LinearRegression on test': 626.3184730275215,
'PooledModel on train': 661.24235853637,
'PooledModel on test': 626.4491242439516}
### Computing the coverage at 95% confidence interval
# Empirical 95% coverage of the pooled model's predictive band.
coverages["PooledModel on test"] = coverage(y_test, predictions_test_pooled, std_test_pooled).item()
coverages["PooledModel on train"] = coverage(y_train, predictions_train_pooled, std_train_pooled).item()
coverages
{'LinearRegression on test': 0.9709677419354839,
'LinearRegression on train': 0.9548022598870056,
'PooledModel on test': 0.948387086391449,
'PooledModel on train': 0.938660204410553}
### Hierarchical model
def paritally_pooled_model(sample_weeks, sample_patient_code, sample_fvc=None):
    """Hierarchical (partially pooled) model: per-patient intercept α and
    slope β drawn from shared hyper-priors, with one global noise scale σ.

    (The function name keeps the original's 'paritally' spelling because
    callers elsewhere in the notebook reference it.)
    """
    μ_α = numpyro.sample("μ_α", dist.Normal(0.0, 500.0))
    σ_α = numpyro.sample("σ_α", dist.HalfNormal(100.0))
    μ_β = numpyro.sample("μ_β", dist.Normal(0.0, 3.0))
    σ_β = numpyro.sample("σ_β", dist.HalfNormal(3.0))
    n_patients = len(np.unique(sample_patient_code))
    # Per-patient coefficients share the hyper-priors above.
    with numpyro.plate("Participants", n_patients):
        α = numpyro.sample("α", dist.Normal(μ_α, σ_α))
        β = numpyro.sample("β", dist.Normal(μ_β, σ_β))
    # σ is used unindexed below, i.e. a single global noise scale.
    σ = numpyro.sample("σ", dist.HalfNormal(100.0))
    FVC_est = α[sample_patient_code] + β[sample_patient_code] * sample_weeks
    with numpyro.plate("data", len(sample_patient_code)):
        numpyro.sample("fvc", dist.Normal(FVC_est, σ), obs=sample_fvc)
# Keyword bundles for fitting and evaluating on each split.
model_kwargs_train = {"sample_weeks": x_train, "sample_patient_code": sample_patient_code_train, "sample_fvc": y_train}
model_kwargs_test = {"sample_weeks": x_test, "sample_patient_code": sample_patient_code_test, "sample_fvc": y_test}
# NUTS on the hierarchical model: 4000 draws after 2000 warmup steps.
nuts_final = NUTS(paritally_pooled_model)
mcmc_final = MCMC(nuts_final, num_samples=4000, num_warmup=2000)
rng_key = random.PRNGKey(0)
mcmc_final.run(rng_key, **model_kwargs_train)
sample: 100%|██████████| 6000/6000 [00:15<00:00, 388.28it/s, 63 steps of size 1.73e-02. acc. prob=0.85]
# Posterior-predictive sampler plus trace diagrams for the hierarchical model.
predictive_final = Predictive(paritally_pooled_model, mcmc_final.get_samples())
az.plot_trace(az.from_numpyro(mcmc_final), compact=True);
Getting MAE and Coverage
# Posterior-predictive MAE and 95% coverage for the hierarchical model.
# (The original also built an unused duplicate sampler
# `predictive_hierarchical`; `predictive_final` above is the one actually
# used everywhere, so the dead duplicate is dropped.)
predictions_train_hierarchical = predictive_final(
    rng_key,
    sample_weeks=model_kwargs_train["sample_weeks"],
    sample_patient_code=model_kwargs_train["sample_patient_code"])['fvc']
mu_predictions_train_h = predictions_train_hierarchical.mean(axis=0)
std_predictions_train_h = predictions_train_hierarchical.std(axis=0)
maes["Hierarchical on train"] = mean_absolute_error(y_train, mu_predictions_train_h)
coverages["Hierarchical on train"] = coverage(y_train, mu_predictions_train_h, std_predictions_train_h).item()
predictions_test_hierarchical = predictive_final(
    rng_key,
    sample_weeks=model_kwargs_test["sample_weeks"],
    sample_patient_code=model_kwargs_test["sample_patient_code"])['fvc']
mu_predictions_test_h = predictions_test_hierarchical.mean(axis=0)
std_predictions_test_h = predictions_test_hierarchical.std(axis=0)
maes["Hierarchical on test"] = mean_absolute_error(y_test, mu_predictions_test_h)
coverages["Hierarchical on test"] = coverage(y_test, mu_predictions_test_h, std_predictions_test_h).item()
maes
{'LinearRegression on train': 662.1236659544445,
'LinearRegression on test': 626.3184730275215,
'PooledModel on train': 661.24235853637,
'PooledModel on test': 626.4491242439516,
'Hierarchical on train': 80.57488639283507,
'Hierarchical on test': 110.49720419606855}
coverages
{'LinearRegression on test': 0.9709677419354839,
'LinearRegression on train': 0.9548022598870056,
'PooledModel on test': 0.948387086391449,
'PooledModel on train': 0.938660204410553,
'Hierarchical on train': 0.9741727113723755,
'Hierarchical on test': 0.9387096762657166}
# Predict for a given patient: posterior-predictive mean and std over all_weeks.
def predict_final(patient_code):
    draws = predictive_final(rng_key, all_weeks, patient_code)["fvc"]
    return draws.mean(axis=0), draws.std(axis=0)
# Plot the predictive posterior for a given patient with a 90% confidence
# interval, as the question requires.
def plot_patient_final(patient_code):
    """Plot predictive mean +/- 90% CI and the patient's observed FVC values.

    patient_code: encoded patient id wrapped in an array, e.g. np.array([0]).
    """
    mu, sigma = predict_final(patient_code)
    plt.plot(all_weeks, mu)
    # 90% CI: mu +/- 1.645 * sigma. (The original shaded only +/- 1 sigma,
    # i.e. a ~68% interval, not the requested 90%.)
    plt.fill_between(all_weeks, mu - 1.645 * sigma, mu + 1.645 * sigma, alpha=0.1)
    id_to_patient = patient_encoder.inverse_transform([patient_code])[0]
    patient_weeks = train[train["Patient"] == id_to_patient]["Weeks"]
    patient_fvc = train[train["Patient"] == id_to_patient]["FVC"]
    plt.scatter(patient_weeks, patient_fvc, alpha=0.5)
    plt.xlabel("Weeks")
    plt.ylabel("FVC")
    plt.title(id_to_patient)
# plot for a given patient
# (expects the encoded integer patient code wrapped in an array)
plot_patient_final(np.array([0]))
### Hierarchical model
def Partially_pooled_sigma_model(sample_weeks, sample_patient_code, sample_fvc=None):
    """Hierarchical model with a per-patient noise scale in addition to
    per-patient α and β: σ_i ~ Exponential(γ_σ) inside the participants plate,
    then indexed per observation in the likelihood.
    """
    μ_α = numpyro.sample("μ_α", dist.Normal(0.0, 500.0))
    σ_α = numpyro.sample("σ_α", dist.HalfNormal(100.0))
    μ_β = numpyro.sample("μ_β", dist.Normal(0.0, 3.0))
    σ_β = numpyro.sample("σ_β", dist.HalfNormal(3.0))
    𝛄_σ = numpyro.sample("𝛄_σ", dist.HalfNormal(30.0))
    n_patients = len(np.unique(sample_patient_code))
    with numpyro.plate("Participants", n_patients):
        α = numpyro.sample("α", dist.Normal(μ_α, σ_α))
        β = numpyro.sample("β", dist.Normal(μ_β, σ_β))
        σ = numpyro.sample("σ", dist.Exponential(𝛄_σ))
    FVC_est = α[sample_patient_code] + β[sample_patient_code] * sample_weeks
    with numpyro.plate("data", len(sample_patient_code)):
        numpyro.sample("fvc", dist.Normal(FVC_est, σ[sample_patient_code]), obs=sample_fvc)
# Fit the per-patient-σ hierarchical model (4000 draws, 2000 warmup).
nuts_kernel = NUTS(Partially_pooled_sigma_model)
mcmc_3 = MCMC(nuts_kernel, num_samples=4000, num_warmup=2000)
rng_key = random.PRNGKey(0)
mcmc_3.run(rng_key, **model_kwargs_train)
sample: 100%|██████████| 6000/6000 [00:26<00:00, 224.98it/s, 63 steps of size 5.69e-02. acc. prob=0.87]
# Trace diagrams for the per-patient-σ model.
az.plot_trace(az.from_numpyro(mcmc_3), compact=True);
# Posterior-predictive MAE and 95% coverage for the per-patient-σ model.
predictive = Predictive(Partially_pooled_sigma_model, mcmc_3.get_samples())
predictive_train_3 = predictive(rng_key,sample_weeks = model_kwargs_train["sample_weeks"],
sample_patient_code = model_kwargs_train["sample_patient_code"])['fvc']
mu_predictions_train_3 = predictive_train_3.mean(axis=0)
std_predictions_train_3 = predictive_train_3.std(axis=0)
maes["Hierarchial sigma train"] = mean_absolute_error(y_train, mu_predictions_train_3)
coverages["Hierarchial sigma train"] = coverage(y_train, mu_predictions_train_3, std_predictions_train_3).item()
predictive_test_3 = predictive(rng_key,sample_weeks = model_kwargs_test["sample_weeks"],
sample_patient_code = model_kwargs_test["sample_patient_code"])['fvc']
mu_predictions_test_3 = predictive_test_3.mean(axis=0)
std_predictions_test_3 = predictive_test_3.std(axis=0)
maes["Hierarchial sigma test"] = mean_absolute_error(y_test, mu_predictions_test_3)
coverages["Hierarchial sigma test"] = coverage(y_test, mu_predictions_test_3, std_predictions_test_3).item()
# Predict for a given patient (per-patient-σ model).
def predict_final_3(patient_code):
    draws = predictive(rng_key, all_weeks, patient_code)["fvc"]
    return draws.mean(axis=0), draws.std(axis=0)
# Plot the predictive posterior for a given patient (per-patient-σ model)
# with a 90% confidence interval, as the question requires.
def plot_patient_final_3(patient_code):
    """Plot predictive mean +/- 90% CI and the patient's observed FVC values.

    patient_code: encoded patient id wrapped in an array, e.g. np.array([32]).
    """
    mu, sigma = predict_final_3(patient_code)
    plt.plot(all_weeks, mu)
    # 90% CI: mu +/- 1.645 * sigma. (The original shaded only +/- 1 sigma,
    # i.e. a ~68% interval, not the requested 90%.)
    plt.fill_between(all_weeks, mu - 1.645 * sigma, mu + 1.645 * sigma, alpha=0.1)
    id_to_patient = patient_encoder.inverse_transform([patient_code])[0]
    patient_weeks = train[train["Patient"] == id_to_patient]["Weeks"]
    patient_fvc = train[train["Patient"] == id_to_patient]["FVC"]
    plt.scatter(patient_weeks, patient_fvc, alpha=0.5)
    plt.xlabel("Weeks")
    plt.ylabel("FVC")
    plt.title(id_to_patient)
# plot for a given patient
plot_patient_final_3(np.array([32]))
# Final MAE comparison across all models (lower is better).
pd.Series(maes)
LinearRegression on train 662.123666 LinearRegression on test 626.318473 PooledModel on train 661.242359 PooledModel on test 626.449124 Hierarchical on train 80.574886 Hierarchical on test 110.497204 Hierarchial sigma train 85.884296 Hierarchial sigma test 111.439091 dtype: float64
# Final coverage comparison (closer to the nominal 95% is better).
pd.Series(coverages)
LinearRegression on test 0.970968 LinearRegression on train 0.954802 PooledModel on test 0.948387 PooledModel on train 0.938660 Hierarchical on train 0.974173 Hierarchical on test 0.938710 Hierarchial sigma train 0.996772 Hierarchial sigma test 0.948387 dtype: float64
Use your version of following models to reproduce figure 4 from the paper referenced at [2].
You can also refer to the notebook in the course.
1) Hypernet [2 marks] 2) Neural Processes [2 marks]
import torch
import os
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
# Remove all the warnings
import warnings
warnings.filterwarnings('ignore')
# Configure CUDA kernel-launch mode: '0' keeps asynchronous launches
# (set to '1' only when debugging CUDA errors).
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '0'
# Use the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Retina display
%config InlineBackend.figure_format = 'retina'
# Mount Google Drive to reach the CelebA image folder (Colab only).
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
# Load every CelebA jpg from Drive as a 64x64 tensor in [0, 1].
import os

transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor()
])

root = "/content/drive/MyDrive/Dataset/CelebA_Dataset"
celeba_dataset = [
    transform(datasets.folder.default_loader(os.path.join(root, fname)))
    for fname in os.listdir(root)
    if fname.endswith('.jpg')
]
len(celeba_dataset)
761
# Load a single reference image as a (3, 64, 64) tensor and preview it.
img0 = transform(datasets.folder.default_loader("/content/drive/MyDrive/Dataset/CelebA_Dataset/000001.jpg"))
print(img0.shape)
plt.imshow(img0.permute(1,2,0))
torch.Size([3, 64, 64])
<matplotlib.image.AxesImage at 0x7c8a5c02a1d0>
from sklearn import preprocessing
def create_scaled_cmap(img, rt=False):
    """Turn an image tensor (C, H, W) into a coordinate -> colour dataset.

    Parameters
    ----------
    img : torch.Tensor of shape (C, H, W)
    rt : if True, flatten pixels row-major HWC (permute before reshape);
         if False, reshape the raw (C, H, W) buffer directly. The two orderings
         group values differently — downstream reconstruction code must invert
         the same ordering used here.

    Returns
    -------
    scaled_X : (H*W, 2) float tensor of (x, y) coords min-max scaled to [-1, 1]
    Y : (H*W, C) float tensor of colour targets
    scaler_X : fitted sklearn MinMaxScaler for mapping new coordinates
    """
    # (Removed the original's no-op `img = img` assignment.)
    num_channels, height, width = img.shape
    # Create a 2D grid of (x, y) coordinates, flattened pixel-by-pixel.
    x_coords = torch.arange(width).repeat(height, 1).reshape(-1)
    y_coords = torch.arange(height).repeat(width, 1).t().reshape(-1)
    X = torch.stack([x_coords, y_coords], dim=1).float().to(device)
    if rt:
        Y = img.permute(1, 2, 0).reshape(-1, num_channels).float().to(device)
    else:
        Y = img.reshape(-1, num_channels).float().to(device)
    scaler_X = preprocessing.MinMaxScaler(feature_range=(-1, 1)).fit(X.cpu())
    scaled_X = torch.tensor(scaler_X.transform(X.cpu())).to(device).float()
    return scaled_X, Y, scaler_X
# Coordinate/colour dataset for the reference image:
# X is (4096, 2) scaled coordinates, Y is (4096, 3) colour targets.
img0_X, img0_Y, scaler_X = create_scaled_cmap(img0)
img0_X.shape, img0_Y.shape
(torch.Size([4096, 2]), torch.Size([4096, 3]))
s = 64  # hidden width shared by all layers of the target net

class NN(nn.Module):
    """Coordinate MLP mapping (x, y) -> n_out colour channels.

    With activation=torch.sin this is a SIREN: weights are initialised with
    the SIREN scheme and the activation is scaled by activation_scale (w0).
    """

    def _init_siren(self, activation_scale):
        # SIREN init: first layer U(-1/fan_in, 1/fan_in); deeper layers
        # U(-sqrt(6/fan_in)/w0, sqrt(6/fan_in)/w0) with w0 = activation_scale.
        self.fc1.weight.data.uniform_(-1/self.fc1.in_features, 1/self.fc1.in_features)
        for layers in [self.fc2, self.fc3, self.fc5]:
            layers.weight.data.uniform_(-np.sqrt(6/self.fc2.in_features)/activation_scale,
                                        np.sqrt(6/self.fc2.in_features)/activation_scale)

    def __init__(self, activation=torch.sin, n_out=3, activation_scale=1.0):
        super().__init__()
        self.activation = activation
        self.activation_scale = activation_scale
        self.fc1 = nn.Linear(2, s)
        self.fc2 = nn.Linear(s, s)
        self.fc3 = nn.Linear(s, s)
        # Fix: honour n_out (1 for grayscale, 3 for RGB) — the original
        # hard-coded 3 here and silently ignored the parameter.
        self.fc5 = nn.Linear(s, n_out)
        if self.activation == torch.sin:
            # init weights and biases for sine activation
            self._init_siren(activation_scale=self.activation_scale)

    def forward(self, x):
        x = self.activation(self.activation_scale*self.fc1(x))
        x = self.activation(self.activation_scale*self.fc2(x))
        x = self.activation(self.activation_scale*self.fc3(x))
        return self.fc5(x)
This will act as our target network (the net whose parameters the hypernet will later produce).
# Making singular data for only 1 image
# Fixed seed so the pixel shuffle is reproducible.
torch.manual_seed(0)
sh_index = torch.randperm(img0_X.shape[0])
# Shuffle the dataset
img0_X_sh = img0_X[sh_index]
img0_Y_sh = img0_Y[sh_index]
sh_index[0:10]
tensor([2732, 1810, 3111, 2738, 155, 2864, 2423, 2918, 2441, 3201])
torch.manual_seed(0)
# Two target nets for the same image: a ReLU baseline and a SIREN (w0 = 30).
nns = {}
nns["img0"] = {}
nns["img0"]["relu"] = NN(activation=torch.relu, n_out=3).to(device)
nns["img0"]["sin"] = NN(activation=torch.sin, n_out=3, activation_scale=30.0).to(device)
nns["img0"]["relu"](img0_X_sh).shape, nns["img0"]["sin"](img0_X_sh).shape
(torch.Size([4096, 3]), torch.Size([4096, 3]))
# Training the network to recreate the image
def train_normalnet(net, lr, X, Y, epochs, verbose=True):
    """Fit `net` to (X, Y) with full-batch Adam on MSE loss.

    net: torch.nn.Module mapping (num_samples, 2) coords to (num_samples, 3)
    lr: Adam learning rate
    X: torch.Tensor of shape (num_samples, 2)
    Y: torch.Tensor of shape (num_samples, 3)
    epochs: number of full-batch gradient steps
    verbose: print the loss every 100 epochs

    Returns the final MSE as a float. With epochs <= 0 no step is taken and
    the current (untrained) MSE is returned — the original raised NameError
    because `loss` was only bound inside the loop.
    """
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    loss = criterion(net(X), Y)  # defined even when the loop never runs
    for epoch in range(epochs):
        optimizer.zero_grad()
        outputs = net(X)
        loss = criterion(outputs, Y)
        loss.backward()
        optimizer.step()
        if verbose and epoch % 100 == 0:
            print(f"Epoch {epoch} loss: {loss.item():.6f}")
    return loss.item()
import time
# Train the ReLU target net for 2000 full-batch steps and time it.
n_iter = 2000
start_time = time.time()
train_normalnet(nns["img0"]["relu"], lr=3e-4, X=img0_X_sh, Y=img0_Y_sh, epochs=n_iter)
end_time = time.time()
print(f"Training time: {end_time-start_time:.2f} seconds")
Epoch 0 loss: 0.330164 Epoch 100 loss: 0.072692 Epoch 200 loss: 0.071209 Epoch 300 loss: 0.070467 Epoch 400 loss: 0.069878 Epoch 500 loss: 0.069198 Epoch 600 loss: 0.068229 Epoch 700 loss: 0.066728 Epoch 800 loss: 0.064489 Epoch 900 loss: 0.061017 Epoch 1000 loss: 0.055909 Epoch 1100 loss: 0.050226 Epoch 1200 loss: 0.045593 Epoch 1300 loss: 0.042071 Epoch 1400 loss: 0.039665 Epoch 1500 loss: 0.037875 Epoch 1600 loss: 0.036552 Epoch 1700 loss: 0.035374 Epoch 1800 loss: 0.034569 Epoch 1900 loss: 0.033748 Training time: 3.97 seconds
# Reconstruct the image from the trained ReLU net on the unshuffled grid.
output = nns["img0"]["relu"](img0_X)
print(output.shape)
num_channels, height, width = img0.shape
# Invert the channel-first flattening used when building Y (rt=False),
# then move channels last for imshow.
recon = output.reshape(num_channels, height, width).permute(1, 2, 0)
fig, ax = plt.subplots(figsize=(4, 3))
ax.imshow(recon.detach().cpu())
ax.set_title("Reconstructed Relu Image")
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
torch.Size([4096, 3])
Text(0.5, 1.0, 'Reconstructed Relu Image')
import time
# Train the SIREN target net for 2000 full-batch steps and time it.
n_iter = 2000
start_time = time.time()
train_normalnet(nns["img0"]["sin"], lr=3e-4, X=img0_X_sh, Y=img0_Y_sh, epochs=n_iter)
end_time = time.time()
print(f"Training time: {end_time-start_time:.2f} seconds")
Epoch 0 loss: 0.393856 Epoch 100 loss: 0.012778 Epoch 200 loss: 0.007427 Epoch 300 loss: 0.005406 Epoch 400 loss: 0.004122 Epoch 500 loss: 0.003208 Epoch 600 loss: 0.002511 Epoch 700 loss: 0.001977 Epoch 800 loss: 0.001601 Epoch 900 loss: 0.001339 Epoch 1000 loss: 0.001124 Epoch 1100 loss: 0.000962 Epoch 1200 loss: 0.000838 Epoch 1300 loss: 0.000752 Epoch 1400 loss: 0.000663 Epoch 1500 loss: 0.000605 Epoch 1600 loss: 0.000568 Epoch 1700 loss: 0.000524 Epoch 1800 loss: 0.000518 Epoch 1900 loss: 0.000470 Training time: 3.48 seconds
# Reconstruct the image from the trained SIREN on the unshuffled grid.
output = nns["img0"]["sin"](img0_X)
print(output.shape)
num_channels, height, width = img0.shape
# Invert the channel-first flattening used when building Y (rt=False),
# then move channels last for imshow.
recon = output.reshape(num_channels, height, width).permute(1, 2, 0)
fig, ax = plt.subplots(figsize=(4, 3))
ax.imshow(recon.detach().cpu())
ax.set_title("Reconstructed Image Siren")
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
torch.Size([4096, 3])
Text(0.5, 1.0, 'Reconstructed Image Siren')
# Here we are looking at the composition of our target net, The hypernet should be able to
# produce those many features for these parameters
try:
    from tabulate import tabulate
except:
    %pip install tabulate
    from tabulate import tabulate
model = nns["img0"]["sin"]
table_data = []
total_params = 0
start = 0
start_end_mapping = {}
# Record each parameter tensor's flat [start, end) slice inside the
# concatenated parameter vector the hypernet must emit.
for name, param in model.named_parameters():
    param_count = torch.prod(torch.tensor(param.shape)).item()
    total_params += param_count
    end = total_params
    table_data.append([name, param.shape, param_count, start, end])
    start_end_mapping[name] = (start, end)
    start = end
print(tabulate(table_data, headers=["Layer Name", "Shape", "Parameter Count", "Start Index", "End Index"]))
print(f"Total number of parameters: {total_params}")
Layer Name Shape Parameter Count Start Index End Index ------------ -------------------- ----------------- ------------- ----------- fc1.weight torch.Size([64, 2]) 128 0 128 fc1.bias torch.Size([64]) 64 128 192 fc2.weight torch.Size([64, 64]) 4096 192 4288 fc2.bias torch.Size([64]) 64 4288 4352 fc3.weight torch.Size([64, 64]) 4096 4352 8448 fc3.bias torch.Size([64]) 64 8448 8512 fc5.weight torch.Size([3, 64]) 192 8512 8704 fc5.bias torch.Size([3]) 3 8704 8707 Total number of parameters: 8707
# Hypernet class
total_params = 8707  # parameter count of the target NN (see table above)
ss = 256             # default hidden width

class HyperNet(nn.Module):
    """Maps a (x, y, r, g, b) context point to a flat target-net parameter vector.

    The output size is fixed to `total_params` so predictions can be reshaped
    into the target NN's weights and biases; `n_out` is therefore overridden
    and kept only for signature compatibility. `num_layers` and `activation`'s
    default mirror the original; `num_layers` is currently unused.
    """

    def __init__(self, num_layers=5, num_neurons=256, activation=torch.sin, n_out=3):
        super().__init__()
        self.activation = activation
        self.n_out = total_params
        # Fix: honour num_neurons for the hidden width — the original ignored
        # it in favour of the global `ss`. Default 256 == ss, so default
        # behaviour is unchanged.
        self.fc1 = nn.Linear(5, num_neurons)
        self.fc2 = nn.Linear(num_neurons, num_neurons)
        self.fc3 = nn.Linear(num_neurons, total_params)

    def forward(self, x):
        x = self.activation(self.fc1(x))
        x = self.activation(self.fc2(x))
        return self.fc3(x)
# Sanity check: the hypernet maps any (n, 5) context batch to (n, total_params);
# averaging over rows yields one flat parameter vector of shape (total_params,).
hp = HyperNet().to(device)
out_hp = hp(torch.rand(10, 5).to(device))
print(out_hp.shape)
weights_flattened = out_hp.mean(dim=0)
print(weights_flattened.shape)
print(hp(torch.rand(2000, 5).to(device)).shape)
torch.Size([10, 8707]) torch.Size([8707]) torch.Size([2000, 8707])
torch.manual_seed(42)
# Build (coords, rgb) training pairs for img0 via the project helper.
img0_X, img0_Y, scaler_X = create_scaled_cmap(img0, rt = False)
img0_X.shape, img0_Y.shape
img0_hyper_data = []
datano = 1 #instead of 20 images
context_precent = 100 # increased the context size to 10 percent
# NOTE(review): the loop below builds a shuffled context set, but its result is
# discarded — img0_hyper_data is reset to [] immediately after the loop and
# repopulated with the full (unshuffled) image as context.
for i in range(datano):
sh_indexi = torch.randperm(img0_X.shape[0])
cont_img0_X_shi = img0_X[sh_indexi][0:int(len(img0_X)*context_precent/100)]
cont_img0_Y_shi = img0_Y[sh_indexi][0:int(len(img0_X)*context_precent/100)]
context_data = torch.cat((cont_img0_X_shi, cont_img0_Y_shi), dim = 1)
train_img0_X_shi = img0_X
train_img0_Y_shi = img0_Y
datas = [context_data, train_img0_X_shi, train_img0_Y_shi]
img0_hyper_data.append(datas)
# Replace the data with a single entry: full image as context AND training target.
img0_hyper_data = []
img0_hyper_data.append([torch.cat((img0_X, img0_Y), dim = 1), img0_X, img0_Y])
img0_hyper_data[0][0].shape, img0_hyper_data[0][1].shape, img0_hyper_data[0][2].shape
(torch.Size([4096, 5]), torch.Size([4096, 2]), torch.Size([4096, 3]))
# Check that averaging the hypernet output over the full-image context
# yields a single flat parameter vector of length total_params.
params2 = hp(img0_hyper_data[0][0]).mean(dim=0)
params2.shape, img0_hyper_data[0][0].shape
(torch.Size([8707]), torch.Size([4096, 5]))
# Instantiate the target net (project-defined NN; SIREN-style with sin activation).
model = NN(activation=torch.sin, n_out=3, activation_scale=30.0).to(device)
model
NN( (fc1): Linear(in_features=2, out_features=64, bias=True) (fc2): Linear(in_features=64, out_features=64, bias=True) (fc3): Linear(in_features=64, out_features=64, bias=True) (fc5): Linear(in_features=64, out_features=3, bias=True) )
# Installing the astra library (provides ravel_pytree for flattening parameter dicts)
! pip install astra-lib
Collecting astra-lib
Downloading astra-lib-0.0.2.tar.gz (136 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 136.6/136.6 kB 3.3 MB/s eta 0:00:00
Installing build dependencies ... done
Getting requirements to build wheel ... done
Preparing metadata (pyproject.toml) ... done
Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from astra-lib) (1.23.5)
Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from astra-lib) (1.5.3)
Requirement already satisfied: matplotlib in /usr/local/lib/python3.10/dist-packages (from astra-lib) (3.7.1)
Requirement already satisfied: xarray in /usr/local/lib/python3.10/dist-packages (from astra-lib) (2023.7.0)
Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from astra-lib) (4.66.1)
Collecting optree (from astra-lib)
Downloading optree-0.10.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (286 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 286.8/286.8 kB 15.3 MB/s eta 0:00:00
Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->astra-lib) (1.2.0)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib->astra-lib) (0.12.1)
Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->astra-lib) (4.44.0)
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->astra-lib) (1.4.5)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->astra-lib) (23.2)
Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->astra-lib) (9.4.0)
Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->astra-lib) (3.1.1)
Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib->astra-lib) (2.8.2)
Requirement already satisfied: typing-extensions>=4.0.0 in /usr/local/lib/python3.10/dist-packages (from optree->astra-lib) (4.5.0)
Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->astra-lib) (2023.3.post1)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.7->matplotlib->astra-lib) (1.16.0)
Building wheels for collected packages: astra-lib
Building wheel for astra-lib (pyproject.toml) ... done
Created wheel for astra-lib: filename=astra_lib-0.0.2-py3-none-any.whl size=21845 sha256=21853f35730a094084175db7468b414103a824ca437ec06c5b23383555786d99
Stored in directory: /root/.cache/pip/wheels/db/a6/8d/73931c696ff5c17a3364e2962cf8680790ae07a3e8aa55587b
Successfully built astra-lib
Installing collected packages: optree, astra-lib
Successfully installed astra-lib-0.0.2 optree-0.10.0
import astra
from astra.torch.utils import ravel_pytree
# ravel_pytree flattens the target net's parameter dict into one 1-D tensor and
# returns an unravel function that maps such a flat vector back to the dict.
flat_weights, unravel_fn = ravel_pytree(dict(model.named_parameters()))
print(flat_weights.shape)
torch.Size([8707])
unravel_fn
<function astra.torch.utils.ravel_pytree.<locals>.unravel_function(flat_params)>
# Training the network to recreate the image
torch.manual_seed(42)
# targetnet = NN(activation=torch.sin, n_out=3, activation_scale=30.0).to(device) # Target Net
# new_dict = targetnet.state_dict()
def train_hypernet(hypernet, target_net, lr, hyper_data, epochs, verbose=True):
    """Train `hypernet` to emit working parameters for `target_net`.

    Parameters
    ----------
    hypernet : nn.Module mapping a (n, 5) context tensor to (n, total_params);
        the only module whose parameters are optimized.
    target_net : nn.Module whose parameter *structure* defines the flat vector;
        its own parameter values are never updated — they are replaced per step
        by the hypernet's output via a stateless functional call.
    lr : float, Adam learning rate.
    hyper_data : list of [context_data, train_X, train_Y] triples.
    epochs : int, number of passes over hyper_data.
    verbose : bool, print the mean loss every 50 epochs.

    Returns
    -------
    list of per-epoch losses (summed over the triples in hyper_data).
    """
    loss_list = []
    datano = len(hyper_data)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(hypernet.parameters(), lr=lr)
    # The target net's parameter structure never changes, so build the
    # flat-vector -> state-dict mapping once instead of on every step.
    _, unravel_fn = ravel_pytree(dict(target_net.named_parameters()))
    for epoch in range(epochs):
        running_loss = 0.0
        for i in range(datano):
            context_data, train_img_X_shi, train_img_Y_shi = hyper_data[i]
            optimizer.zero_grad()
            output = hypernet(context_data)
            # Average the per-context-row predictions into one parameter vector.
            params = output.mean(dim=0)
            new_dict = unravel_fn(params)
            # Run the target net with the hypernet-produced parameters;
            # gradients flow back through `params` into the hypernet.
            outputs = torch.func.functional_call(target_net, new_dict, train_img_X_shi)
            loss = criterion(outputs, train_img_Y_shi)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        loss_list.append(running_loss)
        if verbose and epoch % 50 == 0:
            print(f"Epoch {epoch} loss: {running_loss/datano:.6f}")
    return loss_list
import time
torch.manual_seed(42)
start_time = time.time()
# ReLU hypernet + fresh (untrained) target net; only the hypernet is optimized.
hp1 = HyperNet(activation = nn.ReLU() ).to(device)
targetnet1 = NN().to(device)
loss_list = train_hypernet(hp1, targetnet1, lr = 3e-4, hyper_data = img0_hyper_data, epochs = 15000, verbose=True)
# reduced learning rate
end_time = time.time()
print(f"Training time: {end_time-start_time:.2f} seconds")
Epoch 0 loss: 0.403650 Epoch 50 loss: 0.071302 Epoch 100 loss: 0.071194 Epoch 150 loss: 0.071189 Epoch 200 loss: 0.071183 Epoch 250 loss: 0.071176 Epoch 300 loss: 0.071166 Epoch 350 loss: 0.071154 Epoch 400 loss: 0.071146 Epoch 450 loss: 0.071144 Epoch 500 loss: 0.071142 Epoch 550 loss: 0.071140 Epoch 600 loss: 0.071138 Epoch 650 loss: 0.071136 Epoch 700 loss: 0.071135 Epoch 750 loss: 0.071133 Epoch 800 loss: 0.071132 Epoch 850 loss: 0.071130 Epoch 900 loss: 0.071129 Epoch 950 loss: 0.071127 Epoch 1000 loss: 0.071126 Epoch 1050 loss: 0.071128 Epoch 1100 loss: 0.071125 Epoch 1150 loss: 0.071124 Epoch 1200 loss: 0.071129 Epoch 1250 loss: 0.071126 Epoch 1300 loss: 0.071125 Epoch 1350 loss: 0.071124 Epoch 1400 loss: 0.071123 Epoch 1450 loss: 0.071122 Epoch 1500 loss: 0.071511 Epoch 1550 loss: 0.071131 Epoch 1600 loss: 0.071128 Epoch 1650 loss: 0.071126 Epoch 1700 loss: 0.071125 Epoch 1750 loss: 0.071124 Epoch 1800 loss: 0.071123 Epoch 1850 loss: 0.071122 Epoch 1900 loss: 0.071121 Epoch 1950 loss: 0.071120 Epoch 2000 loss: 0.071119 Epoch 2050 loss: 0.071293 Epoch 2100 loss: 0.071133 Epoch 2150 loss: 0.071130 Epoch 2200 loss: 0.071128 Epoch 2250 loss: 0.071126 Epoch 2300 loss: 0.071125 Epoch 2350 loss: 0.071124 Epoch 2400 loss: 0.071123 Epoch 2450 loss: 0.071122 Epoch 2500 loss: 0.071121 Epoch 2550 loss: 0.071120 Epoch 2600 loss: 0.071119 Epoch 2650 loss: 0.071118 Epoch 2700 loss: 0.071116 Epoch 2750 loss: 0.071143 Epoch 2800 loss: 0.071109 Epoch 2850 loss: 0.070792 Epoch 2900 loss: 0.072253 Epoch 2950 loss: 0.070422 Epoch 3000 loss: 0.070312 Epoch 3050 loss: 0.072265 Epoch 3100 loss: 0.070852 Epoch 3150 loss: 0.070409 Epoch 3200 loss: 0.070321 Epoch 3250 loss: 0.070825 Epoch 3300 loss: 0.070360 Epoch 3350 loss: 0.070298 Epoch 3400 loss: 0.070170 Epoch 3450 loss: 0.069760 Epoch 3500 loss: 0.069480 Epoch 3550 loss: 0.069419 Epoch 3600 loss: 0.069325 Epoch 3650 loss: 0.069534 Epoch 3700 loss: 0.068746 Epoch 3750 loss: 0.068677 Epoch 3800 loss: 0.068613 Epoch 3850 loss: 
0.068559 Epoch 3900 loss: 0.069640 Epoch 3950 loss: 0.068416 Epoch 4000 loss: 0.068115 Epoch 4050 loss: 0.067895 Epoch 4100 loss: 0.067805 Epoch 4150 loss: 0.066609 Epoch 4200 loss: 0.057200 Epoch 4250 loss: 0.053656 Epoch 4300 loss: 0.051547 Epoch 4350 loss: 0.049484 Epoch 4400 loss: 0.047249 Epoch 4450 loss: 0.039190 Epoch 4500 loss: 0.035528 Epoch 4550 loss: 0.035988 Epoch 4600 loss: 0.031441 Epoch 4650 loss: 0.030642 Epoch 4700 loss: 0.029882 Epoch 4750 loss: 0.028969 Epoch 4800 loss: 0.028517 Epoch 4850 loss: 0.027376 Epoch 4900 loss: 0.028707 Epoch 4950 loss: 0.024760 Epoch 5000 loss: 0.024197 Epoch 5050 loss: 0.023947 Epoch 5100 loss: 0.024595 Epoch 5150 loss: 0.022410 Epoch 5200 loss: 0.022466 Epoch 5250 loss: 0.021812 Epoch 5300 loss: 0.020312 Epoch 5350 loss: 0.021673 Epoch 5400 loss: 0.019209 Epoch 5450 loss: 0.018692 Epoch 5500 loss: 0.018397 Epoch 5550 loss: 0.018506 Epoch 5600 loss: 0.017837 Epoch 5650 loss: 0.017525 Epoch 5700 loss: 0.017342 Epoch 5750 loss: 0.017501 Epoch 5800 loss: 0.017087 Epoch 5850 loss: 0.016972 Epoch 5900 loss: 0.016583 Epoch 5950 loss: 0.016422 Epoch 6000 loss: 0.016368 Epoch 6050 loss: 0.016281 Epoch 6100 loss: 0.016753 Epoch 6150 loss: 0.016152 Epoch 6200 loss: 0.016031 Epoch 6250 loss: 0.017693 Epoch 6300 loss: 0.016301 Epoch 6350 loss: 0.016454 Epoch 6400 loss: 0.017366 Epoch 6450 loss: 0.015956 Epoch 6500 loss: 0.015282 Epoch 6550 loss: 0.015224 Epoch 6600 loss: 0.015460 Epoch 6650 loss: 0.015093 Epoch 6700 loss: 0.015143 Epoch 6750 loss: 0.015584 Epoch 6800 loss: 0.014723 Epoch 6850 loss: 0.014729 Epoch 6900 loss: 0.014920 Epoch 6950 loss: 0.014756 Epoch 7000 loss: 0.014658 Epoch 7050 loss: 0.014436 Epoch 7100 loss: 0.015304 Epoch 7150 loss: 0.014708 Epoch 7200 loss: 0.014184 Epoch 7250 loss: 0.014117 Epoch 7300 loss: 0.014174 Epoch 7350 loss: 0.013970 Epoch 7400 loss: 0.014577 Epoch 7450 loss: 0.014161 Epoch 7500 loss: 0.013952 Epoch 7550 loss: 0.013838 Epoch 7600 loss: 0.013753 Epoch 7650 loss: 0.013718 Epoch 7700 
loss: 0.014550 Epoch 7750 loss: 0.014340 Epoch 7800 loss: 0.013824 Epoch 7850 loss: 0.013968 Epoch 7900 loss: 0.015152 Epoch 7950 loss: 0.013289 Epoch 8000 loss: 0.013215 Epoch 8050 loss: 0.013265 Epoch 8100 loss: 0.013107 Epoch 8150 loss: 0.013669 Epoch 8200 loss: 0.013174 Epoch 8250 loss: 0.012806 Epoch 8300 loss: 0.012764 Epoch 8350 loss: 0.012953 Epoch 8400 loss: 0.013129 Epoch 8450 loss: 0.012777 Epoch 8500 loss: 0.012705 Epoch 8550 loss: 0.014406 Epoch 8600 loss: 0.013657 Epoch 8650 loss: 0.012415 Epoch 8700 loss: 0.012482 Epoch 8750 loss: 0.012352 Epoch 8800 loss: 0.012332 Epoch 8850 loss: 0.013093 Epoch 8900 loss: 0.012409 Epoch 8950 loss: 0.012751 Epoch 9000 loss: 0.012107 Epoch 9050 loss: 0.011987 Epoch 9100 loss: 0.011990 Epoch 9150 loss: 0.011914 Epoch 9200 loss: 0.011947 Epoch 9250 loss: 0.011958 Epoch 9300 loss: 0.011884 Epoch 9350 loss: 0.011888 Epoch 9400 loss: 0.011983 Epoch 9450 loss: 0.011705 Epoch 9500 loss: 0.011694 Epoch 9550 loss: 0.012067 Epoch 9600 loss: 0.011904 Epoch 9650 loss: 0.011582 Epoch 9700 loss: 0.012040 Epoch 9750 loss: 0.011692 Epoch 9800 loss: 0.011438 Epoch 9850 loss: 0.011446 Epoch 9900 loss: 0.011637 Epoch 9950 loss: 0.011481 Epoch 10000 loss: 0.011284 Epoch 10050 loss: 0.012001 Epoch 10100 loss: 0.011459 Epoch 10150 loss: 0.011424 Epoch 10200 loss: 0.011292 Epoch 10250 loss: 0.011367 Epoch 10300 loss: 0.011308 Epoch 10350 loss: 0.011074 Epoch 10400 loss: 0.011024 Epoch 10450 loss: 0.011534 Epoch 10500 loss: 0.011289 Epoch 10550 loss: 0.011159 Epoch 10600 loss: 0.011577 Epoch 10650 loss: 0.011358 Epoch 10700 loss: 0.010867 Epoch 10750 loss: 0.011331 Epoch 10800 loss: 0.010989 Epoch 10850 loss: 0.010952 Epoch 10900 loss: 0.010838 Epoch 10950 loss: 0.010806 Epoch 11000 loss: 0.011251 Epoch 11050 loss: 0.010741 Epoch 11100 loss: 0.010635 Epoch 11150 loss: 0.010690 Epoch 11200 loss: 0.010726 Epoch 11250 loss: 0.011051 Epoch 11300 loss: 0.011013 Epoch 11350 loss: 0.011148 Epoch 11400 loss: 0.010769 Epoch 11450 loss: 0.010810 
Epoch 11500 loss: 0.010482 Epoch 11550 loss: 0.010950 Epoch 11600 loss: 0.010358 Epoch 11650 loss: 0.010617 Epoch 11700 loss: 0.010756 Epoch 11750 loss: 0.010459 Epoch 11800 loss: 0.010271 Epoch 11850 loss: 0.010931 Epoch 11900 loss: 0.010484 Epoch 11950 loss: 0.010253 Epoch 12000 loss: 0.010425 Epoch 12050 loss: 0.010247 Epoch 12100 loss: 0.010368 Epoch 12150 loss: 0.010326 Epoch 12200 loss: 0.010047 Epoch 12250 loss: 0.010260 Epoch 12300 loss: 0.010120 Epoch 12350 loss: 0.010563 Epoch 12400 loss: 0.010085 Epoch 12450 loss: 0.009982 Epoch 12500 loss: 0.010174 Epoch 12550 loss: 0.010021 Epoch 12600 loss: 0.009996 Epoch 12650 loss: 0.010109 Epoch 12700 loss: 0.009951 Epoch 12750 loss: 0.009865 Epoch 12800 loss: 0.010103 Epoch 12850 loss: 0.009926 Epoch 12900 loss: 0.010144 Epoch 12950 loss: 0.009936 Epoch 13000 loss: 0.010185 Epoch 13050 loss: 0.010106 Epoch 13100 loss: 0.010005 Epoch 13150 loss: 0.009894 Epoch 13200 loss: 0.009666 Epoch 13250 loss: 0.009817 Epoch 13300 loss: 0.009859 Epoch 13350 loss: 0.009615 Epoch 13400 loss: 0.009594 Epoch 13450 loss: 0.009915 Epoch 13500 loss: 0.009762 Epoch 13550 loss: 0.009700 Epoch 13600 loss: 0.009972 Epoch 13650 loss: 0.010076 Epoch 13700 loss: 0.009769 Epoch 13750 loss: 0.009585 Epoch 13800 loss: 0.009591 Epoch 13850 loss: 0.009793 Epoch 13900 loss: 0.009485 Epoch 13950 loss: 0.009821 Epoch 14000 loss: 0.009819 Epoch 14050 loss: 0.009761 Epoch 14100 loss: 0.009630 Epoch 14150 loss: 0.009575 Epoch 14200 loss: 0.009442 Epoch 14250 loss: 0.010140 Epoch 14300 loss: 0.009336 Epoch 14350 loss: 0.009232 Epoch 14400 loss: 0.010051 Epoch 14450 loss: 0.009685 Epoch 14500 loss: 0.009288 Epoch 14550 loss: 0.009366 Epoch 14600 loss: 0.009287 Epoch 14650 loss: 0.009389 Epoch 14700 loss: 0.009396 Epoch 14750 loss: 0.009255 Epoch 14800 loss: 0.009340 Epoch 14850 loss: 0.009224 Epoch 14900 loss: 0.009066 Epoch 14950 loss: 0.009403 Training time: 286.73 seconds
# Continue training the same hypernet for another 10k epochs.
loss_list2 = train_hypernet(hp1, targetnet1, lr = 3e-4, hyper_data = img0_hyper_data, epochs = 10000, verbose=True)
Epoch 0 loss: 0.009098 Epoch 50 loss: 0.013128 Epoch 100 loss: 0.009980 Epoch 150 loss: 0.009274 Epoch 200 loss: 0.009523 Epoch 250 loss: 0.009007 Epoch 300 loss: 0.008988 Epoch 350 loss: 0.009015 Epoch 400 loss: 0.009092 Epoch 450 loss: 0.009587 Epoch 500 loss: 0.008949 Epoch 550 loss: 0.009135 Epoch 600 loss: 0.009011 Epoch 650 loss: 0.008903 Epoch 700 loss: 0.008935 Epoch 750 loss: 0.009252 Epoch 800 loss: 0.009137 Epoch 850 loss: 0.008866 Epoch 900 loss: 0.009325 Epoch 950 loss: 0.009051 Epoch 1000 loss: 0.008735 Epoch 1050 loss: 0.008841 Epoch 1100 loss: 0.009356 Epoch 1150 loss: 0.008910 Epoch 1200 loss: 0.008656 Epoch 1250 loss: 0.008877 Epoch 1300 loss: 0.008681 Epoch 1350 loss: 0.008711 Epoch 1400 loss: 0.009730 Epoch 1450 loss: 0.009411 Epoch 1500 loss: 0.008878 Epoch 1550 loss: 0.008693 Epoch 1600 loss: 0.009151 Epoch 1650 loss: 0.008871 Epoch 1700 loss: 0.008522 Epoch 1750 loss: 0.008817 Epoch 1800 loss: 0.008764 Epoch 1850 loss: 0.008510 Epoch 1900 loss: 0.008537 Epoch 1950 loss: 0.008460 Epoch 2000 loss: 0.008674 Epoch 2050 loss: 0.009251 Epoch 2100 loss: 0.008363 Epoch 2150 loss: 0.008522 Epoch 2200 loss: 0.008513 Epoch 2250 loss: 0.008552 Epoch 2300 loss: 0.008843 Epoch 2350 loss: 0.008299 Epoch 2400 loss: 0.008840 Epoch 2450 loss: 0.008896 Epoch 2500 loss: 0.008771 Epoch 2550 loss: 0.008288 Epoch 2600 loss: 0.008335 Epoch 2650 loss: 0.008693 Epoch 2700 loss: 0.008339 Epoch 2750 loss: 0.008195 Epoch 2800 loss: 0.008204 Epoch 2850 loss: 0.008378 Epoch 2900 loss: 0.008293 Epoch 2950 loss: 0.008314 Epoch 3000 loss: 0.008148 Epoch 3050 loss: 0.008174 Epoch 3100 loss: 0.008599 Epoch 3150 loss: 0.008363 Epoch 3200 loss: 0.008353 Epoch 3250 loss: 0.008135 Epoch 3300 loss: 0.008423 Epoch 3350 loss: 0.008360 Epoch 3400 loss: 0.008199 Epoch 3450 loss: 0.008243 Epoch 3500 loss: 0.008171 Epoch 3550 loss: 0.008225 Epoch 3600 loss: 0.008203 Epoch 3650 loss: 0.008148 Epoch 3700 loss: 0.008174 Epoch 3750 loss: 0.008039 Epoch 3800 loss: 0.008076 Epoch 3850 loss: 
0.008205 Epoch 3900 loss: 0.008367 Epoch 3950 loss: 0.008299 Epoch 4000 loss: 0.008120 Epoch 4050 loss: 0.008021 Epoch 4100 loss: 0.008654 Epoch 4150 loss: 0.008405 Epoch 4200 loss: 0.008289 Epoch 4250 loss: 0.008199 Epoch 4300 loss: 0.008161 Epoch 4350 loss: 0.008186 Epoch 4400 loss: 0.007934 Epoch 4450 loss: 0.008202 Epoch 4500 loss: 0.008064 Epoch 4550 loss: 0.007923 Epoch 4600 loss: 0.007967 Epoch 4650 loss: 0.008116 Epoch 4700 loss: 0.007972 Epoch 4750 loss: 0.008034 Epoch 4800 loss: 0.008200 Epoch 4850 loss: 0.008282 Epoch 4900 loss: 0.007968 Epoch 4950 loss: 0.007777 Epoch 5000 loss: 0.007802 Epoch 5050 loss: 0.007816 Epoch 5100 loss: 0.008060 Epoch 5150 loss: 0.008199 Epoch 5200 loss: 0.007822 Epoch 5250 loss: 0.007950 Epoch 5300 loss: 0.008030 Epoch 5350 loss: 0.007862 Epoch 5400 loss: 0.007946 Epoch 5450 loss: 0.007973 Epoch 5500 loss: 0.007721 Epoch 5550 loss: 0.007928 Epoch 5600 loss: 0.007797 Epoch 5650 loss: 0.007869 Epoch 5700 loss: 0.008315 Epoch 5750 loss: 0.007892 Epoch 5800 loss: 0.007714 Epoch 5850 loss: 0.008180 Epoch 5900 loss: 0.007852 Epoch 5950 loss: 0.007752 Epoch 6000 loss: 0.008154 Epoch 6050 loss: 0.007822 Epoch 6100 loss: 0.007803 Epoch 6150 loss: 0.007877 Epoch 6200 loss: 0.007957 Epoch 6250 loss: 0.008377 Epoch 6300 loss: 0.007943 Epoch 6350 loss: 0.007683 Epoch 6400 loss: 0.007771 Epoch 6450 loss: 0.007806 Epoch 6500 loss: 0.007923 Epoch 6550 loss: 0.007985 Epoch 6600 loss: 0.007675 Epoch 6650 loss: 0.007631 Epoch 6700 loss: 0.007656 Epoch 6750 loss: 0.008071 Epoch 6800 loss: 0.007907 Epoch 6850 loss: 0.007681 Epoch 6900 loss: 0.008560 Epoch 6950 loss: 0.007597 Epoch 7000 loss: 0.007686 Epoch 7050 loss: 0.007675 Epoch 7100 loss: 0.007700 Epoch 7150 loss: 0.007716 Epoch 7200 loss: 0.008152 Epoch 7250 loss: 0.007822 Epoch 7300 loss: 0.007648 Epoch 7350 loss: 0.007931 Epoch 7400 loss: 0.007609 Epoch 7450 loss: 0.007668 Epoch 7500 loss: 0.007560 Epoch 7550 loss: 0.007940 Epoch 7600 loss: 0.007523 Epoch 7650 loss: 0.007659 Epoch 7700 
loss: 0.007517 Epoch 7750 loss: 0.007579 Epoch 7800 loss: 0.008199 Epoch 7850 loss: 0.007650 Epoch 7900 loss: 0.007635 Epoch 7950 loss: 0.007524 Epoch 8000 loss: 0.007621 Epoch 8050 loss: 0.007606 Epoch 8100 loss: 0.007712 Epoch 8150 loss: 0.007747 Epoch 8200 loss: 0.007711 Epoch 8250 loss: 0.007574 Epoch 8300 loss: 0.009379 Epoch 8350 loss: 0.007912 Epoch 8400 loss: 0.007517 Epoch 8450 loss: 0.007770 Epoch 8500 loss: 0.007442 Epoch 8550 loss: 0.007595 Epoch 8600 loss: 0.007417 Epoch 8650 loss: 0.007445 Epoch 8700 loss: 0.007657 Epoch 8750 loss: 0.007604 Epoch 8800 loss: 0.007451 Epoch 8850 loss: 0.007435 Epoch 8900 loss: 0.007390 Epoch 8950 loss: 0.008437 Epoch 9000 loss: 0.007788 Epoch 9050 loss: 0.007606 Epoch 9100 loss: 0.007656 Epoch 9150 loss: 0.007530 Epoch 9200 loss: 0.007515 Epoch 9250 loss: 0.007492 Epoch 9300 loss: 0.007468 Epoch 9350 loss: 0.007586 Epoch 9400 loss: 0.007702 Epoch 9450 loss: 0.007509 Epoch 9500 loss: 0.007505 Epoch 9550 loss: 0.007531 Epoch 9600 loss: 0.007483 Epoch 9650 loss: 0.007448 Epoch 9700 loss: 0.007610 Epoch 9750 loss: 0.007565 Epoch 9800 loss: 0.007536 Epoch 9850 loss: 0.007473 Epoch 9900 loss: 0.007490 Epoch 9950 loss: 0.008649
# Concatenate both runs' losses, plot the curve, and checkpoint the hypernet.
loss_list.extend(loss_list2)
plt.plot(loss_list[:25000])
plt.xlabel("Epochs")
plt.ylabel("loss")
Text(0, 0.5, 'loss')
torch.save(hp1.state_dict(), "img0_hypernet.pt")
def plot_reconstructed_and_original_image(original_img, hypernet, targetnet, X, context, title=""):
    """Show the hypernet's reconstruction of `original_img` beside the original.

    original_img : (C, H, W) image tensor
    hypernet : trained hypernet producing a flat target-net parameter vector
    targetnet : target NN whose parameters are supplied by the hypernet
    X : (H*W, 2) scaled pixel coordinates covering the full image
    context : (n, 5) context rows fed to the hypernet
    title : figure suptitle
    """
    num_channels, height, width = original_img.shape
    with torch.no_grad():
        # Average the hypernet output over context rows -> one parameter vector.
        params = hypernet(context).mean(dim=0)
        _, unravel_fn = ravel_pytree(dict(targetnet.named_parameters()))
        parameter_dictionary = unravel_fn(params)
        outputs = torch.func.functional_call(targetnet, parameter_dictionary, X)
    # BUG FIX: this previously printed/showed the *global* `output` (the SIREN
    # reconstruction from an earlier cell), not this function's `outputs`.
    print(outputs.shape)
    outputs = outputs.reshape(num_channels, height, width)
    outputs = outputs.permute(1, 2, 0)  # (H, W, C) for imshow
    fig = plt.figure(figsize=(8, 6))
    gs = gridspec.GridSpec(1, 2, width_ratios=[1, 1])
    ax0 = plt.subplot(gs[0])
    ax1 = plt.subplot(gs[1])
    ax0.imshow(outputs.detach().cpu().numpy())
    ax0.set_title("Reconstructed Image")
    ax1.imshow(original_img.cpu().permute(1, 2, 0))
    ax1.set_title("Original Image")
    for a in [ax0, ax1]:
        a.axis("off")
    fig.suptitle(title, y=0.9)
    plt.tight_layout()
# Visual check using the full image as context.
context = torch.cat((img0_X, img0_Y), dim = 1)
plot_reconstructed_and_original_image(img0, hp1, targetnet1, img0_X, context, title="Hypernet")
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
torch.Size([64, 64, 3])
torch.manual_seed(42)
img0_X, img0_Y, scaler_X = create_scaled_cmap(img0, rt = False)
img0_X.shape, img0_Y.shape
# Build several training triples for the same image, each with a different
# randomly-sampled context size (10% up to just under 80% of the pixels),
# so the hypernet learns to work from partial context.
img0_hyper_data2 = []
datano = 5 #instead of 20 images
context_precent_start = 10 # increased the context size to 10 percent
context_percent_end = 80
for i in range(datano):
# Context fraction grows linearly with i; i/datano never reaches 1, so the
# largest context is just below context_percent_end.
contextp = context_precent_start + (i/datano)*(context_percent_end-context_precent_start)
sh_indexi = torch.randperm(img0_X.shape[0])
cont_img0_X_shi = img0_X[sh_indexi][0:int(len(img0_X)*contextp/100)]
cont_img0_Y_shi = img0_Y[sh_indexi][0:int(len(img0_X)*contextp/100)]
context_data = torch.cat((cont_img0_X_shi, cont_img0_Y_shi), dim = 1)
train_img0_X_shi = img0_X
train_img0_Y_shi = img0_Y
datas = [context_data, train_img0_X_shi, train_img0_Y_shi]
img0_hyper_data2.append(datas)
img0_hyper_data2[1][0].shape, img0_hyper_data2[1][1].shape, img0_hyper_data2[1][1].shape
(torch.Size([983, 5]), torch.Size([4096, 2]), torch.Size([4096, 2]))
import time
torch.manual_seed(42)
start_time = time.time()
# Train a fresh hypernet on the variable-context dataset.
hp_img2 = HyperNet(activation = nn.ReLU() ).to(device)
targetnet_img2 = NN().to(device)
loss_list_img2 = train_hypernet(hp_img2, targetnet_img2, lr = 3e-4, hyper_data = img0_hyper_data2, epochs = 10000, verbose=True)
# reduced learning rate
end_time = time.time()
print(f"Training time: {end_time-start_time:.2f} seconds")
Epoch 0 loss: 0.192977 Epoch 50 loss: 0.071178 Epoch 100 loss: 0.071160 Epoch 150 loss: 0.071141 Epoch 200 loss: 0.071152 Epoch 250 loss: 0.071129 Epoch 300 loss: 0.071131 Epoch 350 loss: 0.071149 Epoch 400 loss: 0.071139 Epoch 450 loss: 0.071147 Epoch 500 loss: 0.071271 Epoch 550 loss: 0.071155 Epoch 600 loss: 0.069042 Epoch 650 loss: 0.068657 Epoch 700 loss: 0.055331 Epoch 750 loss: 0.036802 Epoch 800 loss: 0.034901 Epoch 850 loss: 0.028884 Epoch 900 loss: 0.026838 Epoch 950 loss: 0.025977 Epoch 1000 loss: 0.026541 Epoch 1050 loss: 0.022476 Epoch 1100 loss: 0.021346 Epoch 1150 loss: 0.019726 Epoch 1200 loss: 0.021250 Epoch 1250 loss: 0.017727 Epoch 1300 loss: 0.017597 Epoch 1350 loss: 0.017356 Epoch 1400 loss: 0.017615 Epoch 1450 loss: 0.017500 Epoch 1500 loss: 0.015576 Epoch 1550 loss: 0.015585 Epoch 1600 loss: 0.014925 Epoch 1650 loss: 0.014785 Epoch 1700 loss: 0.014516 Epoch 1750 loss: 0.014382 Epoch 1800 loss: 0.014247 Epoch 1850 loss: 0.013916 Epoch 1900 loss: 0.013573 Epoch 1950 loss: 0.013585 Epoch 2000 loss: 0.013410 Epoch 2050 loss: 0.012568 Epoch 2100 loss: 0.012534 Epoch 2150 loss: 0.012586 Epoch 2200 loss: 0.012624 Epoch 2250 loss: 0.012460 Epoch 2300 loss: 0.012309 Epoch 2350 loss: 0.011867 Epoch 2400 loss: 0.012249 Epoch 2450 loss: 0.012101 Epoch 2500 loss: 0.011160 Epoch 2550 loss: 0.011214 Epoch 2600 loss: 0.011539 Epoch 2650 loss: 0.011330 Epoch 2700 loss: 0.011088 Epoch 2750 loss: 0.010844 Epoch 2800 loss: 0.010697 Epoch 2850 loss: 0.010616 Epoch 2900 loss: 0.010692 Epoch 2950 loss: 0.010700 Epoch 3000 loss: 0.010689 Epoch 3050 loss: 0.010260 Epoch 3100 loss: 0.010396 Epoch 3150 loss: 0.009796 Epoch 3200 loss: 0.009536 Epoch 3250 loss: 0.010086 Epoch 3300 loss: 0.009856 Epoch 3350 loss: 0.009641 Epoch 3400 loss: 0.009877 Epoch 3450 loss: 0.009714 Epoch 3500 loss: 0.009413 Epoch 3550 loss: 0.009772 Epoch 3600 loss: 0.009389 Epoch 3650 loss: 0.009169 Epoch 3700 loss: 0.009201 Epoch 3750 loss: 0.009153 Epoch 3800 loss: 0.009148 Epoch 3850 loss: 
0.008968 Epoch 3900 loss: 0.009011 Epoch 3950 loss: 0.009109 Epoch 4000 loss: 0.008994 Epoch 4050 loss: 0.008814 Epoch 4100 loss: 0.009027 Epoch 4150 loss: 0.008648 Epoch 4200 loss: 0.009128 Epoch 4250 loss: 0.008316 Epoch 4300 loss: 0.008570 Epoch 4350 loss: 0.010499 Epoch 4400 loss: 0.008047 Epoch 4450 loss: 0.008221 Epoch 4500 loss: 0.008445 Epoch 4550 loss: 0.008254 Epoch 4600 loss: 0.008390 Epoch 4650 loss: 0.008111 Epoch 4700 loss: 0.009442 Epoch 4750 loss: 0.008505 Epoch 4800 loss: 0.008337 Epoch 4850 loss: 0.008782 Epoch 4900 loss: 0.008167 Epoch 4950 loss: 0.008045 Epoch 5000 loss: 0.008011 Epoch 5050 loss: 0.008032 Epoch 5100 loss: 0.008029 Epoch 5150 loss: 0.007824 Epoch 5200 loss: 0.007987 Epoch 5250 loss: 0.008000 Epoch 5300 loss: 0.008100 Epoch 5350 loss: 0.007967 Epoch 5400 loss: 0.007496 Epoch 5450 loss: 0.007570 Epoch 5500 loss: 0.007435 Epoch 5550 loss: 0.007867 Epoch 5600 loss: 0.007734 Epoch 5650 loss: 0.008091 Epoch 5700 loss: 0.007538 Epoch 5750 loss: 0.007734 Epoch 5800 loss: 0.007423 Epoch 5850 loss: 0.007576 Epoch 5900 loss: 0.007737 Epoch 5950 loss: 0.007457 Epoch 6000 loss: 0.007553 Epoch 6050 loss: 0.007804 Epoch 6100 loss: 0.007460 Epoch 6150 loss: 0.007287 Epoch 6200 loss: 0.007043 Epoch 6250 loss: 0.007309 Epoch 6300 loss: 0.007235 Epoch 6350 loss: 0.007657 Epoch 6400 loss: 0.007053 Epoch 6450 loss: 0.007620 Epoch 6500 loss: 0.006970 Epoch 6550 loss: 0.007006 Epoch 6600 loss: 0.007215 Epoch 6650 loss: 0.007075 Epoch 6700 loss: 0.007141 Epoch 6750 loss: 0.006854 Epoch 6800 loss: 0.007412 Epoch 6850 loss: 0.007005 Epoch 6900 loss: 0.007194 Epoch 6950 loss: 0.006748 Epoch 7000 loss: 0.006982 Epoch 7050 loss: 0.006718 Epoch 7100 loss: 0.006766 Epoch 7150 loss: 0.006730 Epoch 7200 loss: 0.006824 Epoch 7250 loss: 0.007010 Epoch 7300 loss: 0.006894 Epoch 7350 loss: 0.006803 Epoch 7400 loss: 0.006902 Epoch 7450 loss: 0.006521 Epoch 7500 loss: 0.006476 Epoch 7550 loss: 0.006627 Epoch 7600 loss: 0.006803 Epoch 7650 loss: 0.007007 Epoch 7700 
loss: 0.006755 Epoch 7750 loss: 0.006311 Epoch 7800 loss: 0.006439 Epoch 7850 loss: 0.006668 Epoch 7900 loss: 0.007098 Epoch 7950 loss: 0.007805 Epoch 8000 loss: 0.006405 Epoch 8050 loss: 0.006763 Epoch 8100 loss: 0.006640 Epoch 8150 loss: 0.006424 Epoch 8200 loss: 0.006426 Epoch 8250 loss: 0.006576 Epoch 8300 loss: 0.006390 Epoch 8350 loss: 0.006189 Epoch 8400 loss: 0.006715 Epoch 8450 loss: 0.006440 Epoch 8500 loss: 0.006311 Epoch 8550 loss: 0.006487 Epoch 8600 loss: 0.006274 Epoch 8650 loss: 0.006384 Epoch 8700 loss: 0.006170 Epoch 8750 loss: 0.006158 Epoch 8800 loss: 0.006276 Epoch 8850 loss: 0.006181 Epoch 8900 loss: 0.006423 Epoch 8950 loss: 0.006242 Epoch 9000 loss: 0.005993 Epoch 9050 loss: 0.006317 Epoch 9100 loss: 0.005988 Epoch 9150 loss: 0.006118 Epoch 9200 loss: 0.006159 Epoch 9250 loss: 0.006061 Epoch 9300 loss: 0.006200 Epoch 9350 loss: 0.006279 Epoch 9400 loss: 0.006120 Epoch 9450 loss: 0.006521 Epoch 9500 loss: 0.006148 Epoch 9550 loss: 0.006173 Epoch 9600 loss: 0.005929 Epoch 9650 loss: 0.006773 Epoch 9700 loss: 0.005978 Epoch 9750 loss: 0.005842 Epoch 9800 loss: 0.005882 Epoch 9850 loss: 0.006384 Epoch 9900 loss: 0.006089 Epoch 9950 loss: 0.006302 Training time: 396.95 seconds
# Evaluate generalization: reconstruct from only 1% of pixels as context.
test_contextp = 1
sh_indexi = torch.randperm(img0_X.shape[0])
test_cont_img0_X = img0_X[sh_indexi][0:int(len(img0_X)*test_contextp/100)]
test_cont_img0_Y = img0_Y[sh_indexi][0:int(len(img0_X)*test_contextp/100)]
test_context_data = torch.cat((test_cont_img0_X, test_cont_img0_Y), dim = 1)
test_context_data.shape
torch.Size([40, 5])
plot_reconstructed_and_original_image(img0, hp_img2, targetnet_img2, img0_X, test_context_data, title="Hypernet")
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
torch.Size([64, 64, 3])
# Build hypernet training data across multiple CelebA images, each with
# several context-set sizes (10% up to just under 50% of the pixels).
torch.manual_seed(42)
img0_X, img0_Y, scaler_X = create_scaled_cmap(img0, rt = False)
img0_X.shape, img0_Y.shape
no_of_images = 50
images_hyper_data = []
datano = 4  # context variants per image (instead of 20 images)
context_precent_start = 10  # smallest context size, percent of pixels
context_percent_end = 50    # upper bound (never reached, since i/datano < 1)
for celebno in range(no_of_images):
    image_X, image_Y, scaler_X = create_scaled_cmap(celeba_dataset[celebno], rt = False)
    for i in range(datano):
        # Context fraction grows linearly from start toward end across variants.
        contextp = context_precent_start + (i/datano)*(context_percent_end-context_precent_start)
        sh_indexi = torch.randperm(image_X.shape[0])
        # BUG FIX: size the context by the *current* image's pixel count
        # (len(image_X)), not img0's. Identical here because all images are
        # 64x64, but wrong for any other image size.
        cont_img0_X_shi = image_X[sh_indexi][0:int(len(image_X)*contextp/100)]
        cont_img0_Y_shi = image_Y[sh_indexi][0:int(len(image_X)*contextp/100)]
        context_data = torch.cat((cont_img0_X_shi, cont_img0_Y_shi), dim = 1)
        train_img0_X_shi = image_X
        train_img0_Y_shi = image_Y
        datas = [context_data, train_img0_X_shi, train_img0_Y_shi]
        images_hyper_data.append(datas)
(200, torch.Size([409, 5]), torch.Size([4096, 2]), torch.Size([4096, 3]))
import time
torch.manual_seed(42)
start_time = time.time()
# Hypernetwork predicts the target net's flattened parameters from a context set.
celeb_hp = HyperNet(activation = nn.ReLU() ).to(device)
celeb_targetnet = NN().to(device)
celeb_loss_list = train_hypernet(celeb_hp, celeb_targetnet, lr = 3e-4, hyper_data = images_hyper_data, epochs = 4500, verbose=True)
# reduced learning rate
end_time = time.time()
print(f"Training time: {end_time-start_time:.2f} seconds")
Epoch 0 loss: 0.072027 Epoch 50 loss: 0.060346 Epoch 100 loss: 0.060094 Epoch 150 loss: 0.059345 Epoch 200 loss: 0.059046 Epoch 250 loss: 0.057511 Epoch 300 loss: 0.057402 Epoch 350 loss: 0.056238 Epoch 400 loss: 0.055600 Epoch 450 loss: 0.055547 Epoch 500 loss: 0.055558 Epoch 550 loss: 0.054876 Epoch 600 loss: 0.054605 Epoch 650 loss: 0.053750 Epoch 700 loss: 0.052658 Epoch 750 loss: 0.052339 Epoch 800 loss: 0.051500 Epoch 850 loss: 0.051287 Epoch 900 loss: 0.050553 Epoch 950 loss: 0.050224 Epoch 1000 loss: 0.049445 Epoch 1050 loss: 0.049254 Epoch 1100 loss: 0.048603 Epoch 1150 loss: 0.048420 Epoch 1200 loss: 0.047918 Epoch 1250 loss: 0.047415 Epoch 1300 loss: 0.046563 Epoch 1350 loss: 0.045343 Epoch 1400 loss: 0.044586 Epoch 1450 loss: 0.043709 Epoch 1500 loss: 0.040924 Epoch 1550 loss: 0.040019 Epoch 1600 loss: 0.038758 Epoch 1650 loss: 0.037633 Epoch 1700 loss: 0.037167 Epoch 1750 loss: 0.036234 Epoch 1800 loss: 0.035238 Epoch 1850 loss: 0.034266 Epoch 1900 loss: 0.033301 Epoch 1950 loss: 0.031697 Epoch 2000 loss: 0.031217 Epoch 2050 loss: 0.030073 Epoch 2100 loss: 0.029144 Epoch 2150 loss: 0.028427 Epoch 2200 loss: 0.027817 Epoch 2250 loss: 0.027696 Epoch 2300 loss: 0.026979 Epoch 2350 loss: 0.026631 Epoch 2400 loss: 0.026351 Epoch 2450 loss: 0.025369 Epoch 2500 loss: 0.025343 Epoch 2550 loss: 0.025172 Epoch 2600 loss: 0.024360 Epoch 2650 loss: 0.024442 Epoch 2700 loss: 0.024184 Epoch 2750 loss: 0.024068 Epoch 2800 loss: 0.023908 Epoch 2850 loss: 0.023486 Epoch 2900 loss: 0.023201 Epoch 2950 loss: 0.023358 Epoch 3000 loss: 0.022979 Epoch 3050 loss: 0.023087 Epoch 3100 loss: 0.023546 Epoch 3150 loss: 0.022470 Epoch 3200 loss: 0.022867 Epoch 3250 loss: 0.022026 Epoch 3300 loss: 0.022216 Epoch 3350 loss: 0.022131 Epoch 3400 loss: 0.021967 Epoch 3450 loss: 0.021718 Epoch 3500 loss: 0.021446
--------------------------------------------------------------------------- KeyboardInterrupt Traceback (most recent call last) <ipython-input-50-f1ef09ed7c66> in <cell line: 7>() 5 celeb_hp = HyperNet(activation = nn.ReLU() ).to(device) 6 celeb_targetnet = NN().to(device) ----> 7 celeb_loss_list = train_hypernet(celeb_hp, celeb_targetnet, lr = 3e-4, hyper_data = images_hyper_data, epochs = 4500, verbose=True) 8 # reduced learning rate 9 end_time = time.time() <ipython-input-35-3255bd0398d6> in train_hypernet(hypernet, target_net, lr, hyper_data, epochs, verbose) 34 35 loss = criterion(outputs, train_img_Y_shi) ---> 36 loss.backward() 37 optimizer.step() 38 running_loss += loss.item() /usr/local/lib/python3.10/dist-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs) 490 inputs=inputs, 491 ) --> 492 torch.autograd.backward( 493 self, gradient, retain_graph, create_graph, inputs=inputs 494 ) KeyboardInterrupt:
# Persist the trained hypernetwork weights. (Deduplicated: the identical
# torch.save call was previously executed twice in a row.)
torch.save(celeb_hp.state_dict(), "celeb_hp_hypernet.pt")
# Training curve of the hypernet run.
plt.plot(celeb_loss_list)
plt.xlabel("Epochs")
plt.ylabel("loss")
# Evaluate on celebrity image 2 with a 50% context set.
test_contextp = 50
celebno = 2
image_X, image_Y, scaler_X = create_scaled_cmap(celeba_dataset[celebno], rt = False)
# NOTE(review): the permutation is sized by img0_X but indexes image_X — the
# lengths coincide here (both 64x64 grids = 4096 rows), confirm before reuse.
sh_indexi = torch.randperm(img0_X.shape[0])
test_cont_img0_X = image_X[sh_indexi][0:int(len(img0_X)*test_contextp/100)]
test_cont_img0_Y = image_Y[sh_indexi][0:int(len(img0_X)*test_contextp/100)]
test_context_data = torch.cat((test_cont_img0_X, test_cont_img0_Y), dim = 1)
test_context_data.shape
torch.Size([2048, 5])
# Mean-pool the hypernet's per-context-row outputs into one flat parameter vector.
params = celeb_hp(test_context_data).mean(dim=0)
params.shape
torch.Size([8707])
# Flatten the target net's parameters once to obtain an unravel function,
# then rebuild a parameter dict from the hypernet-predicted flat vector.
flat_weights, unravel_fn = ravel_pytree(dict(celeb_targetnet.named_parameters()))
parameter_dictionary = unravel_fn(params)
# Stateless forward pass of the target network with the predicted parameters.
outputs = torch.func.functional_call(celeb_targetnet, parameter_dictionary, img0_X)
print(outputs.shape)  # fixed: previously printed the stale `output` variable
torch.Size([64, 64, 3])
plt.imshow(output.detach().cpu())
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<matplotlib.image.AxesImage at 0x7cd7bb6545b0>
output.permute(1, 2, 0).shape
torch.Size([64, 3, 64])
img0_X
tensor([[-1.0000, -1.0000],
[-0.9683, -1.0000],
[-0.9365, -1.0000],
...,
[ 0.9365, 1.0000],
[ 0.9683, 1.0000],
[ 1.0000, 1.0000]], device='cuda:0')
plt.imshow(output.permute(1,2,0).detach().cpu())
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) <ipython-input-79-0f48977eb536> in <cell line: 1>() ----> 1 plt.imshow(output.permute(1,2,0).detach().cpu()) /usr/local/lib/python3.10/dist-packages/matplotlib/pyplot.py in imshow(X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, interpolation_stage, filternorm, filterrad, resample, url, data, **kwargs) 2693 interpolation_stage=None, filternorm=True, filterrad=4.0, 2694 resample=None, url=None, data=None, **kwargs): -> 2695 __ret = gca().imshow( 2696 X, cmap=cmap, norm=norm, aspect=aspect, 2697 interpolation=interpolation, alpha=alpha, vmin=vmin, /usr/local/lib/python3.10/dist-packages/matplotlib/__init__.py in inner(ax, data, *args, **kwargs) 1440 def inner(ax, *args, data=None, **kwargs): 1441 if data is None: -> 1442 return func(ax, *map(sanitize_sequence, args), **kwargs) 1443 1444 bound = new_sig.bind(ax, *args, **kwargs) /usr/local/lib/python3.10/dist-packages/matplotlib/axes/_axes.py in imshow(self, X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, interpolation_stage, filternorm, filterrad, resample, url, **kwargs) 5663 **kwargs) 5664 -> 5665 im.set_data(X) 5666 im.set_alpha(alpha) 5667 if im.get_clip_path() is None: /usr/local/lib/python3.10/dist-packages/matplotlib/image.py in set_data(self, A) 708 if not (self._A.ndim == 2 709 or self._A.ndim == 3 and self._A.shape[-1] in [3, 4]): --> 710 raise TypeError("Invalid shape {} for image data" 711 .format(self._A.shape)) 712 TypeError: Invalid shape (64, 3, 64) for image data
# Template for reconstructing an image from a context set:
# 1) hypernet maps the context rows to a flat parameter vector (mean-pooled),
# 2) unravel it into the target net's parameter dict,
# 3) stateless forward pass, then reshape channel-first -> channel-last.
params = hypernet(context).mean(dim=0)
flat_weights, unravel_fn = ravel_pytree(dict(targetnet.named_parameters()))
parameter_dictionary = unravel_fn(params)
outputs = torch.func.functional_call(targetnet, parameter_dictionary, X)
print(outputs.shape)  # fixed: previously printed the unrelated `output` variable
outputs = outputs.reshape(num_channels, height, width)
outputs = outputs.permute(1, 2, 0)
# Plot original vs. hypernet reconstruction for the chosen celebrity image.
plot_reconstructed_and_original_image(celeba_dataset[celebno],
                                      celeb_hp, celeb_targetnet, img0_X, test_context_data, title="Hypernet")
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
torch.Size([64, 64, 3])
import torch
import os
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.distributions as dist
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
# Remove all the warnings
import warnings
warnings.filterwarnings('ignore')
# CUDA_LAUNCH_BLOCKING is disabled ('0'); set to '1' for synchronous kernel debugging.
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '0'
# Use the GPU when available; tensors below are created on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Retina display
%config InlineBackend.figure_format = 'retina'
# Mount Google Drive so the CelebA images under MyDrive become accessible.
from google.colab import drive
drive.mount('/content/drive')
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
import os
# Resize every CelebA image to 64x64 and convert it to a tensor.
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor()
])
root = "/content/drive/MyDrive/Dataset/CelebA_Dataset"
celeba_dataset = []
# Load every .jpg under the dataset folder into memory as transformed tensors.
for filename in os.listdir(root):
    if filename.endswith('.jpg'):
        image_path = os.path.join(root, filename)
        image = datasets.folder.default_loader(image_path)
        image = transform(image)
        celeba_dataset.append(image)
len(celeba_dataset)
761
# Load a reference image; the transform yields a (3, 64, 64) channel-first tensor.
img0 = transform(datasets.folder.default_loader("/content/drive/MyDrive/Dataset/CelebA_Dataset/000001.jpg"))
print(img0.shape)
# imshow expects channel-last, hence the (C, H, W) -> (H, W, C) permute.
plt.imshow(img0.permute(1,2,0))
torch.Size([3, 64, 64])
<matplotlib.image.AxesImage at 0x7c8a54cd0f10>
from sklearn import preprocessing
def create_scaled_cmap(img, rt = False):
    """Build a per-pixel regression dataset from an image.

    Parameters
    ----------
    img : torch.Tensor of shape (num_channels, height, width).
    rt : bool
        If True, the image is flattened channel-last so that row i of Y is
        the color of pixel i, matching the (x, y) order of rows in X.
        If False, the channel-first tensor is reshaped directly, so rows of
        Y do NOT correspond pixel-wise to rows of X.

    Returns
    -------
    scaled_X : (H*W, 2) tensor of pixel coordinates scaled to [-1, 1].
    Y : (H*W, C) tensor of target values.
    scaler_X : fitted MinMaxScaler, for inverse-transforming coordinates.
    """
    num_channels, height, width = img.shape
    # 2D grid of (x, y) pixel coordinates, flattened row-major.
    x_coords = torch.arange(width).repeat(height, 1).reshape(-1)
    y_coords = torch.arange(height).repeat(width, 1).t().reshape(-1)
    X = torch.stack([x_coords, y_coords], dim=1).float().to(device)
    if rt == True:
        Y = img.permute(1, 2, 0).reshape(-1, num_channels).float().to(device)
    else:
        # NOTE(review): reshaping (C, H, W) without permuting interleaves the
        # channel planes; kept as-is for backward compatibility with callers.
        Y = img.reshape(-1, num_channels).float().to(device)
    # Fit the scaler on CPU (sklearn), then move the scaled grid back to device.
    scaler_X = preprocessing.MinMaxScaler(feature_range=(-1, 1)).fit(X.cpu())
    scaled_X = torch.tensor(scaler_X.transform(X.cpu())).to(device).float()
    return scaled_X, Y, scaler_X
# Default rt=False: Y rows come from a raw (C, H, W) reshape.
img0_X_scaled, img0_Y, scaler_X = create_scaled_cmap(img0)
img0_X_scaled.shape, img0_Y.shape
(torch.Size([4096, 2]), torch.Size([4096, 3]))
# This neural process model will work for both encoder and decoder
# Smaller model
# Hidden width shared by every layer of the coordinate network.
s = 64
class NN(nn.Module):
    """Small MLP mapping (coordinate [+ context]) vectors to pixel values.

    Used as both the encoder and the decoder of the neural process. With
    ``activation=torch.sin`` the network is a SIREN and its weights receive
    the dedicated SIREN initialisation.
    """
    def _init_siren(self, activation_scale):
        # SIREN init: first layer U(-1/fan_in, 1/fan_in); later layers
        # U(-sqrt(6/fan_in)/w0, +sqrt(6/fan_in)/w0). Fixed to use each
        # layer's own fan-in (previously hard-coded to fc2's, which only
        # worked because every hidden width equals s).
        self.fc1.weight.data.uniform_(-1/self.fc1.in_features, 1/self.fc1.in_features)
        for layer in [self.fc2, self.fc3, self.fc5]:
            bound = np.sqrt(6/layer.in_features)/activation_scale
            layer.weight.data.uniform_(-bound, bound)

    def __init__(self, inp_dim = 5, activation=torch.sin, n_out=3, activation_scale=1.0):
        super().__init__()
        self.activation = activation
        self.activation_scale = activation_scale
        self.fc1 = nn.Linear(inp_dim, s)
        self.fc2 = nn.Linear(s, s)
        self.fc3 = nn.Linear(s, s)
        self.fc5 = nn.Linear(s, n_out)  # grayscale image (1) or RGB (3)
        if self.activation == torch.sin:
            # Sine activations require the SIREN weight initialisation.
            self._init_siren(activation_scale=self.activation_scale)

    def forward(self, x):
        x = self.activation(self.activation_scale*self.fc1(x))
        x = self.activation(self.activation_scale*self.fc2(x))
        x = self.activation(self.activation_scale*self.fc3(x))
        # Final layer is linear (no activation).
        return self.fc5(x)
img0_X_scaled.shape, img0_Y.shape
(torch.Size([4096, 2]), torch.Size([4096, 3]))
torch.manual_seed(40)
# Ten context sets of increasing size (10% -> 50% of pixels) for one image.
img0_np_data = []
datano = 10
context_percent_start = 10
context_percent_end = 50
for i in range(datano):
    contp = context_percent_start + (i/datano)*(context_percent_end - context_percent_start)
    sh_indexi = torch.randperm(img0_X_scaled.shape[0])
    cont_img0_X_shi = img0_X_scaled[sh_indexi][0:int(len(img0_X_scaled)*contp/100)]
    cont_img0_Y_shi = img0_Y[sh_indexi][0:int(len(img0_X_scaled)*contp/100)]
    # Context rows concatenate coordinates with their target values.
    context_data = torch.cat((cont_img0_X_shi, cont_img0_Y_shi), dim = 1)
    train_img0_X_shi = img0_X_scaled #[sh_indexi]
    train_img0_Y_shi = img0_Y #[sh_indexi]
    datas = [context_data, train_img0_X_shi, train_img0_Y_shi]
    img0_np_data.append(datas)
len(img0_np_data), img0_np_data[0][0].shape, img0_np_data[0][1].shape, img0_np_data[0][2].shape
(10, torch.Size([409, 5]), torch.Size([4096, 2]), torch.Size([4096, 3]))
Training the neural process on the dataset
K = 500
# Encoder: (x, y, r, g, b) context rows -> K-dim representation per row.
encoder2 = NN(inp_dim = 2+3, activation=torch.sin, n_out=K ).to(device)
# Decoder: (x, y) plus the pooled K-dim code -> 3 means + 3 scale outputs.
decoder2 = NN(inp_dim = K+2, activation=torch.sin, n_out=6).to(device)
encoder_output = encoder2(img0_np_data[0][0])
# Mean-pool the per-row encodings, then tile across all target coordinates.
enocoded_rep = encoder_output.mean(dim=0).repeat(img0_np_data[0][1].shape[0], 1)
print(img0_np_data[0][0].shape, encoder_output.shape, enocoded_rep.shape)
train_rep = torch.concat((img0_np_data[0][1], enocoded_rep), dim = 1)
print(train_rep.shape)
outputs = decoder2(train_rep)
print(outputs.shape)
def softplus(std, beta = 1, threshold = 20):
    """Numerically stable softplus: (1/beta) * log(1 + exp(beta * x)).

    Fixes two defects in the original: `beta` was not applied inside the
    exponential, and `threshold` was ignored, so large inputs overflowed
    exp() to inf. Above the threshold the function reverts to the identity
    (softplus(x) ~= x there), matching torch.nn.functional.softplus.
    Behavior is unchanged for the default beta=1 on moderate inputs.
    """
    scaled = beta * std
    # clamp before exp so the unselected branch cannot produce inf
    smooth = (1 / beta) * torch.log1p(torch.exp(torch.clamp(scaled, max=threshold)))
    return torch.where(scaled > threshold, std, smooth)
print(type(softplus(outputs[:,3:])))
def normal_loss(mean, log_sigma, actual_val):
    """Mean negative log-likelihood of `actual_val` under N(mean, sigma).

    The scale is bounded below by 0.1 (softplus keeps the added term
    positive) so the likelihood cannot collapse onto single points.
    """
    sigma = 0.1 + 0.9*softplus(log_sigma)
    return -dist.Normal(mean, sigma).log_prob(actual_val).mean()
# One NLL evaluation: columns 0-2 are the means, 3-5 parameterize the scale.
loss = normal_loss(outputs[:,:3], outputs[:,3:], img0_np_data[0][2])
loss
torch.Size([409, 5]) torch.Size([409, 500]) torch.Size([4096, 500]) torch.Size([4096, 502]) torch.Size([4096, 6]) <class 'torch.Tensor'>
tensor(0.8217, device='cuda:0', grad_fn=<NegBackward0>)
def softplus(std, beta = 1, threshold = 20):
    """Numerically stable softplus: (1/beta) * log(1 + exp(beta * x)).

    Same fix as the earlier definition (this cell re-declares it): apply
    `beta` inside the exponential and honor `threshold` so that large
    inputs return the identity instead of overflowing exp() to inf.
    Behavior is unchanged for the default beta=1 on moderate inputs.
    """
    scaled = beta * std
    # clamp before exp so the unselected branch cannot produce inf
    smooth = (1 / beta) * torch.log1p(torch.exp(torch.clamp(scaled, max=threshold)))
    return torch.where(scaled > threshold, std, smooth)
def normal_loss(mean, log_sigma, actual_val):
    """Average negative log-likelihood of actual_val under N(mean, scale)."""
    scale = 0.1 + 0.9 * softplus(log_sigma)
    likelihood = dist.Normal(mean, scale)
    return -likelihood.log_prob(actual_val).mean()
def train_np(encoder, decoder, np_data, lr, epochs, verbose=True):
    """Train a neural-process encoder/decoder pair jointly.

    Parameters
    ----------
    encoder : nn.Module mapping context rows to a per-row K-dim code.
    decoder : nn.Module mapping (coords, pooled code) to 3 means + 3 scales.
    np_data : list of [context_data, target_X, target_Y] triples.
    lr : float, Adam learning rate.
    epochs : int, number of passes over np_data.
    verbose : bool, print the running mean loss every 50 epochs.

    Returns
    -------
    list of per-epoch summed losses (one entry per epoch).

    Fixed: the old docstring documented nonexistent parameters (net/X/Y);
    dead commented-out code removed; the index loop now iterates np_data
    directly. Behavior is unchanged.
    """
    # One optimizer over both modules so they are trained jointly.
    optimizer = torch.optim.Adam([*encoder.parameters(), *decoder.parameters()], lr=lr)
    datano = len(np_data)
    total_loss = []
    for epoch in range(epochs):
        running_loss = 0.0
        for context_data, train_X, train_Y in np_data:
            optimizer.zero_grad()
            encoder_output = encoder(context_data)
            # Mean-pool the row encodings into one code, tile over all targets.
            encoded_rep = encoder_output.mean(dim=0).repeat(train_X.shape[0], 1)
            train_rep = torch.concat((train_X, encoded_rep), dim = 1)
            outputs = decoder(train_rep)
            # First 3 output columns are means, the last 3 parameterize sigma.
            loss = normal_loss(outputs[:,:3], outputs[:,3:], train_Y)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        total_loss.append(running_loss)
        if verbose and epoch % 50 == 0:
            print(f"Epoch {epoch} loss: {running_loss/datano:.6f}")
    return total_loss
import time
torch.manual_seed(0)
# can experiment with k
K = 500
# Train the K=500 neural process on the single-image dataset and time it.
encoder2 = NN(inp_dim = 2+3, activation=torch.sin, n_out=K ).to(device)
decoder2 = NN(inp_dim = K+2, activation=torch.sin, n_out=6).to(device)
start_time = time.time()
loss_list = train_np(encoder2, decoder2, img0_np_data, lr=1e-3, epochs=8000, verbose=True)
end_time = time.time()
print(f"Training time: {end_time-start_time:.2f} seconds")
Epoch 0 loss: 1.337767 Epoch 50 loss: 0.094281 Epoch 100 loss: 0.091550 Epoch 150 loss: 0.085551 Epoch 200 loss: 0.078786 Epoch 250 loss: 0.075036 Epoch 300 loss: 0.068512 Epoch 350 loss: 0.057991 Epoch 400 loss: 0.032758 Epoch 450 loss: -0.116420 Epoch 500 loss: -0.232012 Epoch 550 loss: -0.305583 Epoch 600 loss: -0.349394 Epoch 650 loss: -0.386280 Epoch 700 loss: -0.409056 Epoch 750 loss: -0.445883 Epoch 800 loss: -0.460155 Epoch 850 loss: -0.474480 Epoch 900 loss: -0.505245 Epoch 950 loss: -0.487306 Epoch 1000 loss: -0.353294 Epoch 1050 loss: -0.533210 Epoch 1100 loss: -0.557027 Epoch 1150 loss: -0.426544 Epoch 1200 loss: -0.576683 Epoch 1250 loss: -0.502492 Epoch 1300 loss: -0.594420 Epoch 1350 loss: -0.607830 Epoch 1400 loss: -0.585463 Epoch 1450 loss: -0.637908 Epoch 1500 loss: -0.650928 Epoch 1550 loss: -0.464352 Epoch 1600 loss: -0.671424 Epoch 1650 loss: -0.626369 Epoch 1700 loss: -0.683977 Epoch 1750 loss: -0.699354 Epoch 1800 loss: -0.689624 Epoch 1850 loss: -0.718512 Epoch 1900 loss: -0.661768 Epoch 1950 loss: -0.701058 Epoch 2000 loss: -0.762148 Epoch 2050 loss: -0.741034 Epoch 2100 loss: -0.774478 Epoch 2150 loss: -0.762232 Epoch 2200 loss: -0.792295 Epoch 2250 loss: -0.753801 Epoch 2300 loss: -0.802900 Epoch 2350 loss: -0.579035 Epoch 2400 loss: -0.814304 Epoch 2450 loss: -0.759662 Epoch 2500 loss: -0.820780 Epoch 2550 loss: -0.829672 Epoch 2600 loss: -0.690268 Epoch 2650 loss: -0.836858 Epoch 2700 loss: -0.813810 Epoch 2750 loss: -0.848221 Epoch 2800 loss: -0.854806 Epoch 2850 loss: -0.763804 Epoch 2900 loss: -0.863535 Epoch 2950 loss: -0.871811 Epoch 3000 loss: -0.555845 Epoch 3050 loss: -0.878646 Epoch 3100 loss: -0.885052 Epoch 3150 loss: -0.739801 Epoch 3200 loss: -0.890640 Epoch 3250 loss: -0.889438 Epoch 3300 loss: -0.900909 Epoch 3350 loss: -0.782015 Epoch 3400 loss: -0.909211 Epoch 3450 loss: -0.911320 Epoch 3500 loss: -0.912907 Epoch 3550 loss: -0.922259 Epoch 3600 loss: -0.760309 Epoch 3650 loss: -0.929416 Epoch 3700 loss: -0.868978 Epoch 
3750 loss: -0.927377 Epoch 3800 loss: -0.938281 Epoch 3850 loss: -0.939244 Epoch 3900 loss: -0.930846 Epoch 3950 loss: -0.939761 Epoch 4000 loss: -0.952568 Epoch 4050 loss: -0.924327 Epoch 4100 loss: -0.959611 Epoch 4150 loss: -0.962178 Epoch 4200 loss: -0.960338 Epoch 4250 loss: -0.967879 Epoch 4300 loss: -0.972552 Epoch 4350 loss: -0.968258 Epoch 4400 loss: -0.970368 Epoch 4450 loss: -0.931210 Epoch 4500 loss: -0.979965 Epoch 4550 loss: -0.755416 Epoch 4600 loss: -0.985051 Epoch 4650 loss: -0.989418 Epoch 4700 loss: -0.881498 Epoch 4750 loss: -0.993088 Epoch 4800 loss: -0.976235 Epoch 4850 loss: -0.996279 Epoch 4900 loss: -0.999940 Epoch 4950 loss: -0.921534 Epoch 5000 loss: -1.002399 Epoch 5050 loss: -1.006093 Epoch 5100 loss: -0.990569 Epoch 5150 loss: -1.008689 Epoch 5200 loss: -1.000557 Epoch 5250 loss: -0.952158 Epoch 5300 loss: -1.013231 Epoch 5350 loss: -0.943541 Epoch 5400 loss: -1.013609 Epoch 5450 loss: -1.007544 Epoch 5500 loss: -1.012452 Epoch 5550 loss: -1.018967 Epoch 5600 loss: -1.024164 Epoch 5650 loss: -0.907021 Epoch 5700 loss: -1.026700 Epoch 5750 loss: -1.028270 Epoch 5800 loss: -0.917459 Epoch 5850 loss: -1.030782 Epoch 5900 loss: -1.032855 Epoch 5950 loss: -0.976612 Epoch 6000 loss: -1.031469 Epoch 6050 loss: -1.032512 Epoch 6100 loss: -1.011055 Epoch 6150 loss: -1.040253 Epoch 6200 loss: -1.044241 Epoch 6250 loss: -0.965589 Epoch 6300 loss: -1.043470 Epoch 6350 loss: -1.048069 Epoch 6400 loss: -0.946986 Epoch 6450 loss: -1.050458 Epoch 6500 loss: -1.050313 Epoch 6550 loss: -0.924406 Epoch 6600 loss: -1.053265 Epoch 6650 loss: -1.056274 Epoch 6700 loss: -1.010929 Epoch 6750 loss: -1.058735 Epoch 6800 loss: -1.056873 Epoch 6850 loss: -0.995843 Epoch 6900 loss: -1.063446 Epoch 6950 loss: -1.004386 Epoch 7000 loss: -1.064945 Epoch 7050 loss: -1.051375 Epoch 7100 loss: -1.065526 Epoch 7150 loss: -1.067781 Epoch 7200 loss: -1.004980 Epoch 7250 loss: -1.072607 Epoch 7300 loss: -1.024302 Epoch 7350 loss: -1.074019 Epoch 7400 loss: -1.072818 Epoch 
7450 loss: -1.053316 Epoch 7500 loss: -1.065668 Epoch 7550 loss: -0.961100 Epoch 7600 loss: -1.073951 Epoch 7650 loss: -1.082017 Epoch 7700 loss: -1.001805 Epoch 7750 loss: -1.077227 Epoch 7800 loss: -1.083825 Epoch 7850 loss: -1.081036 Epoch 7900 loss: -1.087032 Epoch 7950 loss: -1.051721 Training time: 316.47 seconds
# loss_list2 = train_np(encoder2, decoder2, img0_np_data, lr=1e-3, epochs=4000, verbose=True)
# Persist the single-image neural-process weights.
torch.save(encoder2.state_dict(), "img0_encoder2_hypernet.pt")
torch.save(decoder2.state_dict(), "img0_decoder2_hypernet.pt")
# Training curve (per-epoch summed NLL; negative values are expected).
plt.plot(loss_list)
plt.xlabel("Iterations")
plt.ylabel("loss")
Text(0, 0.5, 'loss')
def plot_np_image(encoder, decoder, datano, key, context_percent, title = ""):
    """Reconstruct image `datano` from a `context_percent`% context set and
    plot original, reconstruction, context points, and predicted variance.

    `key` seeds the RNG so the sampled context set is reproducible.
    """
    torch.manual_seed(key)
    # rt=True: Y rows align pixel-wise with the coordinate rows of scaled_X.
    scaled_X, Y, scaler_X = create_scaled_cmap(celeba_dataset[datano], rt = True)
    #now, the context has also changed
    sh_index = torch.randperm(scaled_X.shape[0])
    # if int(len(scaled_X)*context_percent/100) > 6000:
    #     print("context size is too big")
    #     return None
    cont_img_X = scaled_X[sh_index][0:int(len(scaled_X)*context_percent/100)]
    cont_img_Y = Y[sh_index][0:int(len(scaled_X)*context_percent/100)]
    context = torch.cat((cont_img_X, cont_img_Y), dim = 1)
    encoder_output = encoder(context)
    # Mean-pool the per-row encodings and tile across every coordinate.
    encoded_rep = encoder_output.mean(dim=0).repeat(scaled_X.shape[0], 1)
    train_rep = torch.cat((scaled_X, encoded_rep), dim = 1)
    output = decoder(train_rep)
    # NOTE(review): training uses softplus(log_sigma) without squaring; the
    # extra **2 here makes the plotted variance inconsistent — confirm intent.
    var = 0.1 + 0.9*softplus(output[:,3:]**2) #output[:,3:]**2
    output = output[:,:3]
    num_channels, height, width = celeba_dataset[datano].shape
    # NOTE(review): the reshape assumes decoder rows are laid out channel-major
    # (as produced by rt=False data); verify against the rt=True targets above.
    output = output.reshape(num_channels, height, width)
    output = output.permute(1, 2, 0)
    fig = plt.figure(figsize=(12, 6))
    gs = gridspec.GridSpec(1, 4, width_ratios=[1, 1, 1, 1])
    ax0 = plt.subplot(gs[0])
    ax1 = plt.subplot(gs[1])
    ax2 = plt.subplot(gs[2])
    ax3 = plt.subplot(gs[3])
    ax0.imshow(celeba_dataset[datano].cpu().permute(1, 2, 0))
    ax0.set_title("Original Image")
    ax0.axis("off")
    ax1.imshow(output.detach().cpu())
    ax1.set_title("Reconstructed Image")
    ax1.axis("off")
    # White canvas with only the context pixels painted at their locations.
    actual_img = torch.ones(celeba_dataset[datano].shape).permute(1, 2, 0)
    # actual_img = (celeba_dataset[datano][0].cpu()*0).permute(1, 2, 0)
    cont_img_X_unscaled = scaler_X.inverse_transform(cont_img_X.cpu())
    for i,x in enumerate(cont_img_X_unscaled):
        # +0.5 rounds the inverse-transformed float coordinates to pixel indices.
        actual_img[int(x[1]+0.5), int(x[0]+0.5)] = torch.tensor(cont_img_Y[i].cpu().detach().numpy())
    ax2.imshow(actual_img)
    # ax2.scatter(cont_img_X[:, 0].detach().cpu(), cont_img_X[:, 1].detach().cpu(), s=10, c='r')
    ax2.set_title("Context Points")
    ax2.axis("off")
    var = var.reshape(num_channels, height, width)
    var = var.permute(1, 2, 0)
    ax3.imshow(var.detach().cpu())
    ax3.set_title("Variance")
    ax3.axis("off")
    fig.suptitle(title, y=0.9)
    plt.tight_layout()
plot_np_image(encoder2, decoder2, datano = 0, key = 41, context_percent = 20, title = "20% context K = 500")
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
import time
torch.manual_seed(0)
# can experiment with k
K = 200
# Repeat the experiment with a smaller representation size.
encoder3 = NN(inp_dim = 2+3, activation=torch.sin, n_out=K ).to(device)
decoder3 = NN(inp_dim = K+2, activation=torch.sin, n_out=6).to(device)
start_time = time.time()
loss_list3 = train_np(encoder3, decoder3, img0_np_data, lr=1e-3, epochs=8000, verbose=True)
end_time = time.time()
print(f"Training time: {end_time-start_time:.2f} seconds")
Epoch 0 loss: 0.985983 Epoch 50 loss: 0.094548 Epoch 100 loss: 0.094396 Epoch 150 loss: 0.094360 Epoch 200 loss: 0.094433 Epoch 250 loss: 0.086920 Epoch 300 loss: 0.076237 Epoch 350 loss: 0.068208 Epoch 400 loss: 0.061454 Epoch 450 loss: 0.051312 Epoch 500 loss: 0.002067 Epoch 550 loss: -0.171355 Epoch 600 loss: -0.258447 Epoch 650 loss: -0.351243 Epoch 700 loss: -0.318880 Epoch 750 loss: -0.430725 Epoch 800 loss: -0.443474 Epoch 850 loss: -0.442860 Epoch 900 loss: -0.471925 Epoch 950 loss: -0.354867 Epoch 1000 loss: -0.487307 Epoch 1050 loss: -0.485608 Epoch 1100 loss: -0.503016 Epoch 1150 loss: -0.463705 Epoch 1200 loss: -0.516598 Epoch 1250 loss: -0.413420 Epoch 1300 loss: -0.511021 Epoch 1350 loss: -0.534636 Epoch 1400 loss: -0.533560 Epoch 1450 loss: -0.550606 Epoch 1500 loss: -0.501212 Epoch 1550 loss: -0.564356 Epoch 1600 loss: -0.565869 Epoch 1650 loss: -0.566113 Epoch 1700 loss: -0.583977 Epoch 1750 loss: -0.594680 Epoch 1800 loss: -0.601825 Epoch 1850 loss: -0.603967 Epoch 1900 loss: -0.630418 Epoch 1950 loss: -0.629093 Epoch 2000 loss: -0.642894 Epoch 2050 loss: -0.628538 Epoch 2100 loss: -0.660608 Epoch 2150 loss: -0.671160 Epoch 2200 loss: -0.680648 Epoch 2250 loss: -0.657872 Epoch 2300 loss: -0.689692 Epoch 2350 loss: -0.688498 Epoch 2400 loss: -0.700271 Epoch 2450 loss: -0.712134 Epoch 2500 loss: -0.718355 Epoch 2550 loss: -0.538734 Epoch 2600 loss: -0.725233 Epoch 2650 loss: -0.732642 Epoch 2700 loss: -0.717478 Epoch 2750 loss: -0.742109 Epoch 2800 loss: -0.751388 Epoch 2850 loss: -0.749122 Epoch 2900 loss: -0.760358 Epoch 2950 loss: -0.767813 Epoch 3000 loss: -0.562763 Epoch 3050 loss: -0.774568 Epoch 3100 loss: -0.715387 Epoch 3150 loss: -0.785606 Epoch 3200 loss: -0.517251 Epoch 3250 loss: -0.788772 Epoch 3300 loss: -0.793651 Epoch 3350 loss: -0.796951 Epoch 3400 loss: -0.798574 Epoch 3450 loss: -0.807807 Epoch 3500 loss: -0.807657 Epoch 3550 loss: -0.814677 Epoch 3600 loss: -0.698318 Epoch 3650 loss: -0.820988 Epoch 3700 loss: -0.786445 Epoch 
3750 loss: -0.827814 Epoch 3800 loss: -0.832542 Epoch 3850 loss: -0.604163 Epoch 3900 loss: -0.835482 Epoch 3950 loss: -0.836838 Epoch 4000 loss: -0.842324 Epoch 4050 loss: -0.741636 Epoch 4100 loss: -0.848315 Epoch 4150 loss: -0.809173 Epoch 4200 loss: -0.851881 Epoch 4250 loss: -0.854729 Epoch 4300 loss: -0.839670 Epoch 4350 loss: -0.856656 Epoch 4400 loss: -0.861078 Epoch 4450 loss: -0.862912 Epoch 4500 loss: -0.866721 Epoch 4550 loss: -0.860558 Epoch 4600 loss: -0.866106 Epoch 4650 loss: -0.840890 Epoch 4700 loss: -0.871417 Epoch 4750 loss: -0.877970 Epoch 4800 loss: -0.685538 Epoch 4850 loss: -0.881884 Epoch 4900 loss: -0.885006 Epoch 4950 loss: -0.852265 Epoch 5000 loss: -0.888602 Epoch 5050 loss: -0.887027 Epoch 5100 loss: -0.886767 Epoch 5150 loss: -0.886860 Epoch 5200 loss: -0.894198 Epoch 5250 loss: -0.892315 Epoch 5300 loss: -0.898365 Epoch 5350 loss: -0.742378 Epoch 5400 loss: -0.901049 Epoch 5450 loss: -0.901360 Epoch 5500 loss: -0.902117 Epoch 5550 loss: -0.899075 Epoch 5600 loss: -0.907264 Epoch 5650 loss: -0.902203 Epoch 5700 loss: -0.906515 Epoch 5750 loss: -0.904309 Epoch 5800 loss: -0.914406 Epoch 5850 loss: -0.914590 Epoch 5900 loss: -0.911589 Epoch 5950 loss: -0.917603 Epoch 6000 loss: -0.883401 Epoch 6050 loss: -0.920532 Epoch 6100 loss: -0.912924 Epoch 6150 loss: -0.922672 Epoch 6200 loss: -0.924746 Epoch 6250 loss: -0.883664 Epoch 6300 loss: -0.924836 Epoch 6350 loss: -0.925393 Epoch 6400 loss: -0.926199 Epoch 6450 loss: -0.875144 Epoch 6500 loss: -0.930490 Epoch 6550 loss: -0.924770 Epoch 6600 loss: -0.855461 Epoch 6650 loss: -0.934010 Epoch 6700 loss: -0.907176 Epoch 6750 loss: -0.935576 Epoch 6800 loss: -0.932074 Epoch 6850 loss: -0.937173 Epoch 6900 loss: -0.924708 Epoch 6950 loss: -0.939628 Epoch 7000 loss: -0.888961 Epoch 7050 loss: -0.940577 Epoch 7100 loss: -0.938378 Epoch 7150 loss: -0.941364 Epoch 7200 loss: -0.915255 Epoch 7250 loss: -0.941812 Epoch 7300 loss: -0.944872 Epoch 7350 loss: -0.917466 Epoch 7400 loss: -0.946301 Epoch 
7450 loss: -0.812070 Epoch 7500 loss: -0.947961 Epoch 7550 loss: -0.949372 Epoch 7600 loss: -0.750076 Epoch 7650 loss: -0.949743 Epoch 7700 loss: -0.951500 Epoch 7750 loss: -0.931081 Epoch 7800 loss: -0.937685 Epoch 7850 loss: -0.953335 Epoch 7900 loss: -0.952670 Epoch 7950 loss: -0.954446 Training time: 330.42 seconds
# Training curve for the K=200 model.
plt.plot(loss_list3)
plt.xlabel("Iterations")
plt.ylabel("loss")
Text(0, 0.5, 'loss')
plot_np_image(encoder3, decoder3, datano = 0, key = 41, context_percent = 20, title = "20% context K = 200")
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Higher variance for smaller K
import time
torch.manual_seed(0)
# can experiment with k
K = 1000
# Repeat the experiment with a larger representation size.
encoder4 = NN(inp_dim = 2+3, activation=torch.sin, n_out=K ).to(device)
decoder4 = NN(inp_dim = K+2, activation=torch.sin, n_out=6).to(device)
start_time = time.time()
loss_list4 = train_np(encoder4, decoder4, img0_np_data, lr=1e-3, epochs=8000, verbose=True)
end_time = time.time()
print(f"Training time: {end_time-start_time:.2f} seconds")
Epoch 0 loss: 3.247491 Epoch 50 loss: 0.095072 Epoch 100 loss: 0.094696 Epoch 150 loss: 0.095501 Epoch 200 loss: 0.092775 Epoch 250 loss: 0.105793 Epoch 300 loss: 0.078307 Epoch 350 loss: 0.960751 Epoch 400 loss: 0.472558 Epoch 450 loss: 0.091825 Epoch 500 loss: 0.088676 Epoch 550 loss: 0.087290 Epoch 600 loss: 0.085511 Epoch 650 loss: 0.516163 Epoch 700 loss: 0.225607 Epoch 750 loss: 0.090968 Epoch 800 loss: 0.100098 Epoch 850 loss: 0.088283 Epoch 900 loss: 0.080500 Epoch 950 loss: 0.080115 Epoch 1000 loss: 0.364954 Epoch 1050 loss: 0.084921 Epoch 1100 loss: 0.078579 Epoch 1150 loss: 0.076788 Epoch 1200 loss: 0.081680 Epoch 1250 loss: 0.078172 Epoch 1300 loss: 0.069085 Epoch 1350 loss: 0.064416 Epoch 1400 loss: 0.296190 Epoch 1450 loss: 0.061623 Epoch 1500 loss: 0.050863 Epoch 1550 loss: 0.006544 Epoch 1600 loss: 0.018891 Epoch 1650 loss: -0.085501 Epoch 1700 loss: -0.113087 Epoch 1750 loss: -0.225822 Epoch 1800 loss: -0.237431 Epoch 1850 loss: -0.276261 Epoch 1900 loss: -0.392886 Epoch 1950 loss: -0.414422 Epoch 2000 loss: -0.444768 Epoch 2050 loss: -0.464655 Epoch 2100 loss: -0.506322 Epoch 2150 loss: -0.525111 Epoch 2200 loss: -0.593114 Epoch 2250 loss: -0.599869 Epoch 2300 loss: -0.608259 Epoch 2350 loss: -0.640627 Epoch 2400 loss: -0.578990 Epoch 2450 loss: -0.684827 Epoch 2500 loss: -0.749774 Epoch 2550 loss: -0.768743 Epoch 2600 loss: -0.657960 Epoch 2650 loss: -0.783812 Epoch 2700 loss: -0.628939 Epoch 2750 loss: -0.817731 Epoch 2800 loss: -0.764647 Epoch 2850 loss: -0.847248 Epoch 2900 loss: -0.866466 Epoch 2950 loss: -0.645291 Epoch 3000 loss: -0.770316 Epoch 3050 loss: -0.874695 Epoch 3100 loss: -0.901799 Epoch 3150 loss: -0.864566 Epoch 3200 loss: -0.919835 Epoch 3250 loss: -0.858242 Epoch 3300 loss: -0.905720 Epoch 3350 loss: -0.878959 Epoch 3400 loss: -0.859680 Epoch 3450 loss: -0.489101 Epoch 3500 loss: -0.952121 Epoch 3550 loss: -0.941618 Epoch 3600 loss: -0.894202 Epoch 3650 loss: -0.915352 Epoch 3700 loss: -0.953912 Epoch 3750 loss: -0.722261 
Epoch 3800 loss: -0.959066 Epoch 3850 loss: -0.724264 Epoch 3900 loss: -0.996326 Epoch 3950 loss: -0.999575 Epoch 4000 loss: -0.811981 Epoch 4050 loss: -0.868557 Epoch 4100 loss: -0.856301 Epoch 4150 loss: -0.871763 Epoch 4200 loss: -0.902275 Epoch 4250 loss: -0.870940 Epoch 4300 loss: -1.014184 Epoch 4350 loss: -0.979585 Epoch 4400 loss: -1.015361 Epoch 4450 loss: -0.978255 Epoch 4500 loss: -0.834727 Epoch 4550 loss: -0.914563 Epoch 4600 loss: -1.051477 Epoch 4650 loss: -1.026398 Epoch 4700 loss: -1.011141 Epoch 4750 loss: -1.044407 Epoch 4800 loss: -0.998591 Epoch 4850 loss: -1.062151 Epoch 4900 loss: -0.842098 Epoch 4950 loss: -1.062960 Epoch 5000 loss: -0.959533 Epoch 5050 loss: -1.064741 Epoch 5100 loss: -0.958881 Epoch 5150 loss: -1.076099 Epoch 5200 loss: -0.927501 Epoch 5250 loss: -1.054025 Epoch 5300 loss: -0.935108 Epoch 5350 loss: -1.082859 Epoch 5400 loss: -0.941729 Epoch 5450 loss: -1.084702 Epoch 5500 loss: -0.919471 Epoch 5550 loss: -1.005004 Epoch 5600 loss: -1.025450 Epoch 5650 loss: -1.093088 Epoch 5700 loss: -0.580471 Epoch 5750 loss: -1.090108 Epoch 5800 loss: -1.047101 Epoch 5850 loss: -1.105287 Epoch 5900 loss: -1.054601 Epoch 5950 loss: -1.036265 Epoch 6000 loss: -0.904680 Epoch 6050 loss: -0.792468 Epoch 6100 loss: -1.111913 Epoch 6150 loss: -0.964587 Epoch 6200 loss: -1.117073 Epoch 6250 loss: -1.122940 Epoch 6300 loss: -0.303059 Epoch 6350 loss: -1.100989 Epoch 6400 loss: -1.123604 Epoch 6450 loss: -1.111797 Epoch 6500 loss: -1.066253 Epoch 6550 loss: -1.021821 Epoch 6600 loss: -1.128251 Epoch 6650 loss: -1.057411 Epoch 6700 loss: -1.126673 Epoch 6750 loss: -1.105785 Epoch 6800 loss: -0.909334 Epoch 6850 loss: -0.948805 Epoch 6900 loss: -1.133896 Epoch 6950 loss: -1.073505 Epoch 7000 loss: -1.136847 Epoch 7050 loss: -1.073588 Epoch 7100 loss: -1.009102 Epoch 7150 loss: -1.140496 Epoch 7200 loss: -1.034243 Epoch 7250 loss: -1.141284 Epoch 7300 loss: -1.034567 Epoch 7350 loss: -0.994248 Epoch 7400 loss: -1.137774 Epoch 7450 loss: -1.035262 
Epoch 7500 loss: -1.130714 Epoch 7550 loss: -1.141972 Epoch 7600 loss: -1.146804 Epoch 7650 loss: -1.139201 Epoch 7700 loss: -1.146221 Epoch 7750 loss: -1.093645 Epoch 7800 loss: -0.959759 Epoch 7850 loss: -1.150753 Epoch 7900 loss: -0.820156 Epoch 7950 loss: -1.154293 Training time: 317.62 seconds
# Training curve for the K=1000 model.
plt.plot(loss_list4)
plt.xlabel("Iterations")
plt.ylabel("loss")
Text(0, 0.5, 'loss')
plot_np_image(encoder4, decoder4, datano = 0, key = 41, context_percent = 20, title = "20% context K = 1000")
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Timing: 10 data points for 8000 epochs took ~4 min, so 200 data points for 3500 epochs is estimated at ~35 min.
200*3500*4/(10*8000)
35.0
torch.manual_seed(42)
# Multi-image neural-process dataset: for each of the first `celebno` images,
# build `datano` context sets of increasing size.
celeb_np_data = []
celebno = 50
datano = 4 # 100
context_percent_start = 10
context_percent_end = 50
for i in range(celebno):
    scaled_X, Y, scaler_X = create_scaled_cmap(celeba_dataset[i])
    for j in range(datano):
        # Fixed: previously used (i/datano), the OUTER image index, which drove
        # the context fraction far past context_percent_end (up to 500%). The
        # inner index j ramps the context size from 10% to 50% as intended,
        # matching the analogous hypernet data loop earlier in the notebook.
        context_percent = context_percent_start + (j/datano)*(context_percent_end - context_percent_start)
        sh_index = torch.randperm(scaled_X.shape[0])
        cont_img_Xi = scaled_X[sh_index][0:int(len(scaled_X)*context_percent/100)]
        cont_img_Yi = Y[sh_index][0:int(len(scaled_X)*context_percent/100)]
        context_data = torch.cat((cont_img_Xi, cont_img_Yi), dim = 1)
        train_img_Xi = scaled_X[sh_index]
        train_img_Yi = Y[sh_index]
        datas = [context_data, train_img_Xi, train_img_Yi]
        celeb_np_data.append(datas)
len(celeb_np_data) ,celeb_np_data[0][0].shape, celeb_np_data[0][1].shape, celeb_np_data[0][2].shape
(200, torch.Size([409, 5]), torch.Size([4096, 2]), torch.Size([4096, 3]))
import time
torch.manual_seed(0)
K = 1000 # can change based on experiment
# Encoder: context point (2-D coord + RGB = 5 inputs) -> K-dim representation.
# Decoder: (K-dim representation + 2-D query coord) -> 6 outputs
# (presumably per-channel mean and variance for RGB — TODO confirm against NN/train_np).
celeb_encoder = NN(inp_dim = 2+3, activation=torch.sin, n_out=K ).to(device)
celeb_decoder = NN(inp_dim = K+2, activation=torch.sin, n_out=6).to(device)
# Time the 3500-epoch training run.
start_time = time.time()
celeb_loss_list = train_np(celeb_encoder, celeb_decoder, celeb_np_data, lr=1e-3, epochs= 3500, verbose=True)
end_time = time.time()
print(f"Training time: {end_time-start_time:.2f} seconds")
Epoch 0 loss: 0.561838 Epoch 50 loss: 0.062497 Epoch 100 loss: 0.049615 Epoch 150 loss: -0.013535 Epoch 200 loss: -0.008642 Epoch 250 loss: -0.063802 Epoch 300 loss: -0.103127 Epoch 350 loss: -0.132578 Epoch 400 loss: -0.150624 Epoch 450 loss: -0.161050 Epoch 500 loss: -0.163209 Epoch 550 loss: -0.184284 Epoch 600 loss: -0.186982 Epoch 650 loss: -0.179719 Epoch 700 loss: -0.203929 Epoch 750 loss: -0.195653 Epoch 800 loss: -0.209538 Epoch 850 loss: -0.202713 Epoch 900 loss: -0.212755 Epoch 950 loss: -0.203270 Epoch 1000 loss: -0.210044 Epoch 1050 loss: -0.209445 Epoch 1100 loss: -0.223716 Epoch 1150 loss: -0.230180 Epoch 1200 loss: -0.219785 Epoch 1250 loss: -0.221717 Epoch 1300 loss: -0.229299 Epoch 1350 loss: -0.354515 Epoch 1400 loss: -0.472582 Epoch 1450 loss: -0.508528 Epoch 1500 loss: -0.522056 Epoch 1550 loss: -0.524381 Epoch 1600 loss: -0.533695 Epoch 1650 loss: -0.543165 Epoch 1700 loss: -0.532182 Epoch 1750 loss: -0.565511 Epoch 1800 loss: -0.564247 Epoch 1850 loss: -0.541689 Epoch 1900 loss: -0.565920 Epoch 1950 loss: -0.556600 Epoch 2000 loss: -0.563846 Epoch 2050 loss: -0.586694 Epoch 2100 loss: -0.570260 Epoch 2150 loss: -0.581060 Epoch 2200 loss: -0.584690 Epoch 2250 loss: -0.589820 Epoch 2300 loss: -0.581520 Epoch 2350 loss: -0.588501 Epoch 2400 loss: -0.588752 Epoch 2450 loss: -0.604451 Epoch 2500 loss: -0.589753 Epoch 2550 loss: -0.565134 Epoch 2600 loss: -0.581715 Epoch 2650 loss: -0.610918 Epoch 2700 loss: -0.577788 Epoch 2750 loss: -0.594610 Epoch 2800 loss: -0.606971 Epoch 2850 loss: -0.606520 Epoch 2900 loss: -0.607718 Epoch 2950 loss: -0.594243 Epoch 3000 loss: -0.576128 Epoch 3050 loss: -0.603738 Epoch 3100 loss: -0.586131 Epoch 3150 loss: -0.601565 Epoch 3200 loss: -0.614933 Epoch 3250 loss: -0.616316 Epoch 3300 loss: -0.609108 Epoch 3350 loss: -0.624380 Epoch 3400 loss: -0.611612 Epoch 3450 loss: -0.622663 Training time: 2831.85 seconds
# Persist the trained weights, then plot the training loss curve.
torch.save(celeb_encoder.state_dict(), "celeb_encoder1_hypernet.pt")
torch.save(celeb_decoder.state_dict(), "celeb_decoder1_hypernet.pt")
plt.plot(celeb_loss_list)
plt.xlabel("Iterations")
plt.ylabel("loss")
Text(0, 0.5, 'loss')
plot_np_image(celeb_encoder, celeb_decoder, datano = 3, key = 41, context_percent = 10, title = "10% context")
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
plot_np_image(celeb_encoder, celeb_decoder, datano = 3, key = 41, context_percent = 50, title = "50% context, K = 1000")
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
plot_np_image(celeb_encoder, celeb_decoder, datano = 3, key = 41, context_percent = 1, title = "1% context, K = 1000")
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
plot_np_image(celeb_encoder, celeb_decoder, datano = 5, key = 41, context_percent = 100, title = "100% context, K = 1000")
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
# Rebuild the CelebA Neural-Process training set (this cell duplicates the
# earlier dataset construction for the second training run).
torch.manual_seed(42)
celeb_np_data = []
celebno = 50
datano = 4 # 100
context_percent_start = 10
context_percent_end = 50
for i in range(celebno):
    scaled_X, Y, scaler_X = create_scaled_cmap(celeba_dataset[i])
    for j in range(datano):
        # Ramp the context fraction from start to end across the `datano`
        # draws of each image.
        # FIX: the original used `i/datano` (up to 12.5), overshooting
        # context_percent_end; `j` is the per-image draw index the ramp
        # clearly intends (`j` was otherwise unused).
        context_percent = context_percent_start + (j/datano)*(context_percent_end - context_percent_start)
        # Shuffle pixel order, then take the first context_percent% as context.
        sh_index = torch.randperm(scaled_X.shape[0])
        n_context = int(len(scaled_X)*context_percent/100)
        cont_img_Xi = scaled_X[sh_index][0:n_context]
        cont_img_Yi = Y[sh_index][0:n_context]
        context_data = torch.cat((cont_img_Xi, cont_img_Yi), dim = 1)
        # Targets are the full shuffled image (coords and RGB values).
        train_img_Xi = scaled_X[sh_index]
        train_img_Yi = Y[sh_index]
        datas = [context_data, train_img_Xi, train_img_Yi]
        celeb_np_data.append(datas)
len(celeb_np_data) ,celeb_np_data[0][0].shape, celeb_np_data[0][1].shape, celeb_np_data[0][2].shape
(200, torch.Size([409, 5]), torch.Size([4096, 2]), torch.Size([4096, 3]))
import time
torch.manual_seed(0)
K = 1000 # can change based on experiment
# Second run: identical architecture to the first, trained longer (6000 epochs).
celeb_encoder2 = NN(inp_dim = 2+3, activation=torch.sin, n_out=K ).to(device)
celeb_decoder2 = NN(inp_dim = K+2, activation=torch.sin, n_out=6).to(device)
# Time the training run.
start_time = time.time()
celeb_loss_list2 = train_np(celeb_encoder2, celeb_decoder2, celeb_np_data, lr=1e-3, epochs= 6000, verbose=True)
end_time = time.time()
print(f"Training time: {end_time-start_time:.2f} seconds")
Epoch 0 loss: 0.479685 Epoch 50 loss: 0.168451 Epoch 100 loss: 0.089300 Epoch 150 loss: 0.095178 Epoch 200 loss: 0.069774 Epoch 250 loss: 0.023266 Epoch 300 loss: -0.022957 Epoch 350 loss: -0.045284 Epoch 400 loss: -0.066094 Epoch 450 loss: -0.083994 Epoch 500 loss: -0.102610 Epoch 550 loss: -0.112196 Epoch 600 loss: -0.118784 Epoch 650 loss: -0.126227 Epoch 700 loss: -0.124259 Epoch 750 loss: -0.124713 Epoch 800 loss: -0.130140 Epoch 850 loss: -0.138738 Epoch 900 loss: -0.143222 Epoch 950 loss: -0.142787 Epoch 1000 loss: -0.149772 Epoch 1050 loss: -0.145987 Epoch 1100 loss: -0.183938 Epoch 1150 loss: -0.297947 Epoch 1200 loss: -0.390763 Epoch 1250 loss: -0.417370 Epoch 1300 loss: -0.438389 Epoch 1350 loss: -0.451033 Epoch 1400 loss: -0.494825 Epoch 1450 loss: -0.488731 Epoch 1500 loss: -0.514693 Epoch 1550 loss: -0.542114 Epoch 1600 loss: -0.530419 Epoch 1650 loss: -0.525035 Epoch 1700 loss: -0.531059 Epoch 1750 loss: -0.543996 Epoch 1800 loss: -0.556629 Epoch 1850 loss: -0.549620 Epoch 1900 loss: -0.563347 Epoch 1950 loss: -0.552585 Epoch 2000 loss: -0.555627 Epoch 2050 loss: -0.569125 Epoch 2100 loss: -0.581341 Epoch 2150 loss: -0.550785 Epoch 2200 loss: -0.577023 Epoch 2250 loss: -0.594627 Epoch 2300 loss: -0.564444 Epoch 2350 loss: -0.584900 Epoch 2400 loss: -0.597912 Epoch 2450 loss: -0.580392 Epoch 2500 loss: -0.586710 Epoch 2550 loss: -0.581483 Epoch 2600 loss: -0.580890 Epoch 2650 loss: -0.604032 Epoch 2700 loss: -0.589725 Epoch 2750 loss: -0.588449 Epoch 2800 loss: -0.588983 Epoch 2850 loss: -0.596554 Epoch 2900 loss: -0.607114 Epoch 2950 loss: -0.596147 Epoch 3000 loss: -0.599351 Epoch 3050 loss: -0.602844 Epoch 3100 loss: -0.606371 Epoch 3150 loss: -0.612399 Epoch 3200 loss: -0.623650 Epoch 3250 loss: -0.598958 Epoch 3300 loss: -0.614483 Epoch 3350 loss: -0.624101 Epoch 3400 loss: -0.621168 Epoch 3450 loss: -0.605707 Epoch 3500 loss: -0.607578 Epoch 3550 loss: -0.626449 Epoch 3600 loss: -0.598677 Epoch 3650 loss: -0.631481 Epoch 3700 loss: -0.631332 
Epoch 3750 loss: -0.623939 Epoch 3800 loss: -0.618339 Epoch 3850 loss: -0.623305 Epoch 3900 loss: -0.630795 Epoch 3950 loss: -0.650639 Epoch 4000 loss: -0.631177 Epoch 4050 loss: -0.634612 Epoch 4100 loss: -0.637233 Epoch 4150 loss: -0.638567 Epoch 4200 loss: -0.621981 Epoch 4250 loss: -0.613415 Epoch 4300 loss: -0.647063 Epoch 4350 loss: -0.645125 Epoch 4400 loss: -0.633680 Epoch 4450 loss: -0.614909 Epoch 4500 loss: -0.649519 Epoch 4550 loss: -0.656437 Epoch 4600 loss: -0.647346 Epoch 4650 loss: -0.643765 Epoch 4700 loss: -0.647531 Epoch 4750 loss: -0.647105 Epoch 4800 loss: -0.637083 Epoch 4850 loss: -0.666762 Epoch 4900 loss: -0.662398 Epoch 4950 loss: -0.665553 Epoch 5000 loss: -0.655768 Epoch 5050 loss: -0.649421 Epoch 5100 loss: -0.654450 Epoch 5150 loss: -0.664157 Epoch 5200 loss: -0.657926 Epoch 5250 loss: -0.659550 Epoch 5300 loss: -0.625096 Epoch 5350 loss: -0.659383 Epoch 5400 loss: -0.661412 Epoch 5450 loss: -0.678888 Epoch 5500 loss: -0.663640 Epoch 5550 loss: -0.674701 Epoch 5600 loss: -0.631961 Epoch 5650 loss: -0.684443 Epoch 5700 loss: -0.635810 Epoch 5750 loss: -0.646522 Epoch 5800 loss: -0.658710 Epoch 5850 loss: -0.668171 Epoch 5900 loss: -0.669239 Epoch 5950 loss: -0.660185 Training time: 4817.50 seconds
# Save the 6000-epoch model's weights and plot its training loss.
torch.save(celeb_encoder2.state_dict(), "celeb_encoder2_hypernet.pt")
torch.save(celeb_decoder2.state_dict(), "celeb_decoder2_hypernet.pt")
plt.plot(celeb_loss_list2)
plt.xlabel("Iterations")
plt.ylabel("loss")
Text(0, 0.5, 'loss')
plot_np_image(celeb_encoder2, celeb_decoder2, datano = 10, key = 39, context_percent = 50, title = "50% context")
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Write the Random Walk Metropolis-Hastings algorithm from scratch.
Take 1000 samples using below given log probs and compare the mean and covariance matrix with hamiltorch’s standard HMC and emcee’s Metropolis Hastings implementation. Use 500 samples as the burn/warm up samples.
Also check the relation between acceptance ratio and the sigma of the proposal distribution in your from scratch implementation. Use the log likelihood function given below.
import torch
import torch.distributions as dist
def log_likelihood(omega):
    """Log density at `omega` of a diagonal 2-D Gaussian target.

    The target has mean (0, 0) and standard deviations (0.5, 1); the
    trailing .sum() collapses any batch dimension to a scalar.
    """
    target = dist.MultivariateNormal(
        torch.tensor([0., 0.]),
        torch.diag(torch.tensor([0.5, 1.]) ** 2),
    )
    return target.log_prob(omega).sum()
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
# The grid: 100 x 100 nodes spanning [-5, 5] in each dimension.
x1 = np.linspace(-5, 5, 100)
x2 = np.linspace(-5, 5, 100)
X1, X2 = np.meshgrid(x1, x2)
Z = np.zeros_like(X1)
Z_exp = np.zeros_like(X1)
# Evaluate the function at each point on the grid: Z holds the log
# likelihood, Z_exp the (exponentiated) likelihood.
for i, j in np.ndindex(X1.shape):
    Z[i, j] = log_likelihood(torch.tensor([X1[i, j], X2[i, j]]))
    Z_exp[i, j] = np.exp(Z[i, j])
# Plot the contours
# 2-D contour plot of the log likelihood over the grid computed above.
fig, ax = plt.subplots(figsize=(6, 4))
contour = ax.contour(X1, X2, Z, levels=50, cmap='viridis')
colorbar = plt.colorbar(contour, ax=ax, label='log likelihood')
ax.set_xlabel('x1')
ax.set_ylabel('x2')
ax.set_title('Log likelihood contour')
Text(0.5, 1.0, 'Log likelihood contour')
# 3-D Plot of the contours
# Surface plot of the exponentiated log likelihood (the density itself).
fig = plt.figure(figsize=(6, 4))
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(X1, X2, Z_exp, cmap='viridis')
ax.set_xlabel('x1'); ax.set_ylabel('x2'); ax.set_zlabel('likelihood')
ax.set_title("Probability Density Function")
plt.show()
The acceptance ratio of the Metropolis-Hastings algorithm is:
$$ r = \frac{p(\theta^{*})/ J_{t}(\theta^{*}|\theta^{t-1})}{p(\theta^{t-1}) / J_{t}(\theta^{t-1}|\theta^{*})} $$Where $J_{t}(\theta^{*}|\theta^{t-1})$ is the jump distribution at time $t$.
import torch
def random_walk_metropolis_hastings(log_likelihood, initial_state, num_samples, jump_function, burn_in=500):
    """Random-walk Metropolis-Hastings sampler.

    Parameters
    ----------
    log_likelihood : callable
        Returns the (unnormalised) log target density at a state tensor.
    initial_state : torch.Tensor
        Starting point of the chain.
    num_samples : int
        Number of MH proposal iterations to run.
    jump_function : callable
        Maps the current state to a torch distribution, used both to draw
        and to score proposals (the Hastings correction below supports
        asymmetric jumps).
    burn_in : int, optional
        Number of leading samples dropped from the first return value.
        Default 500 (previously hard-coded), matching the assignment's
        warm-up requirement.

    Returns
    -------
    tuple
        (post_burn_in_samples, all_samples, acceptance_rate).
    """
    samples = [initial_state]
    current_state = initial_state
    # Cache the current state's log density; re-evaluating it every
    # iteration (as the original did) is redundant work.
    log_prob_current = log_likelihood(current_state)
    accepted_count = 0
    for _ in range(num_samples):
        # Propose a new state from the jump distribution.
        proposal = jump_function(current_state).sample()
        # Hastings log-ratio:
        #   log p(x*) - log p(x) + log J(x | x*) - log J(x* | x)
        log_prob_proposal = log_likelihood(proposal)
        log_prob_jump = jump_function(current_state).log_prob(proposal)
        log_prob_rev_jump = jump_function(proposal).log_prob(current_state)
        log_ratio = log_prob_proposal - log_prob_current + log_prob_rev_jump - log_prob_jump
        # Accept with probability min(1, exp(log_ratio)). Comparing in log
        # space avoids overflow in exp() for large positive ratios; the
        # accept/reject decision is identical to the exp-space form.
        if torch.log(torch.rand(1)) < log_ratio:
            current_state = proposal
            log_prob_current = log_prob_proposal
            accepted_count += 1
        # FIX: the original decremented the loop index on rejection
        # (`i -= 1`), which is a no-op in a Python `for` loop; removed.
        # The current state is recorded every iteration either way.
        samples.append(current_state)
    acceptance_rate = accepted_count / num_samples
    return torch.stack(samples[burn_in:]), torch.stack(samples), acceptance_rate
torch.manual_seed(0)
# NOTE(review): despite the name, this is the proposal *covariance*
# (0.1 * I), not a standard deviation.
normal_stddev = torch.tensor([[0.1,0],[0,0.1]])#0.1
def gaussian_jump(current_state):
    # Symmetric Gaussian random-walk proposal centred on the current state.
    return dist.MultivariateNormal(current_state, normal_stddev)
# Start the chain away from the mode so the burn-in phase is visible.
initial_state = torch.tensor([-4, 3.5])
jump_function = gaussian_jump
num_samples = 1500
samples, all_samples, accp_rate = random_walk_metropolis_hastings(log_likelihood, initial_state, num_samples, jump_function)
accp_rate, samples.shape, all_samples.shape
(0.756, torch.Size([1001, 2]), torch.Size([1501, 2]))
def plot_samples(samples, title, lines = False):
    """Scatter MCMC samples over the target's log-likelihood contours.

    When `lines` is True the chain's path is also drawn, which makes
    mixing (and burn-in) visible.
    """
    # Rebuild the evaluation grid over [-5, 5]^2.
    x1 = np.linspace(-5, 5, 100)
    x2 = np.linspace(-5, 5, 100)
    X1, X2 = np.meshgrid(x1, x2)
    Z = np.zeros_like(X1)
    Z_exp = np.zeros_like(X1)
    # Evaluate the log likelihood (and its exponential) at every grid node.
    for i, j in np.ndindex(X1.shape):
        Z[i, j] = log_likelihood(torch.tensor([X1[i, j], X2[i, j]]))
        Z_exp[i, j] = np.exp(Z[i, j])
    fig, ax = plt.subplots(figsize=(8, 6))
    pts = samples.numpy()
    if lines:
        # Trace the chain's path through state space.
        ax.plot(pts[:, 0], pts[:, 1], alpha=0.5, color = "red", label='Samples')
    # Individual samples are always scattered, even when the path is drawn.
    ax.scatter(pts[:, 0], pts[:, 1], s=2, alpha=0.5, label='Samples')
    contour = ax.contour(X1, X2, Z, levels=50, cmap='viridis')
    colorbar = plt.colorbar(contour, ax=ax, label='log likelihood')
    ax.legend()
    fig.suptitle(title)
    ax.set_xlabel('x1')
    ax.set_ylabel('x2')
    plt.plot  # NOTE(review): bare attribute access is a no-op — probably intended plt.show()
# Visualise the full chain (including burn-in) with its path.
plot_samples(all_samples, f'Random Walk Metropolis Hastings Acceptance rate = {accp_rate:.2f}', lines = True)
# Empirical mean and covariance of the post-burn-in samples; the target has
# mean [0, 0] and covariance diag(0.25, 1).
mean = torch.mean(samples, dim=0)
cov = torch.tensor(np.cov(samples.T))
samples.shape, mean, cov
(torch.Size([1001, 2]),
tensor([0.0460, 0.1631]),
tensor([[0.2572, 0.0197],
[0.0197, 0.7183]], dtype=torch.float64))
import hamiltorch
torch.manual_seed(0)
# Same starting point as the from-scratch MH chain, for a fair comparison.
params_init = torch.tensor([-4, 3.5])
# Standard HMC: 1500 samples, leapfrog step size 0.1, 10 steps per sample.
samples_hmc = hamiltorch.sample(log_prob_func= log_likelihood, params_init=params_init, num_samples=1500,
                                step_size=0.1, num_steps_per_sample=10)
Sampling (Sampler.HMC; Integrator.IMPLICIT) Time spent | Time remain.| Progress | Samples | Samples/sec 0d:00:00:09 | 0d:00:00:00 | #################### | 1500/1500 | 158.76 Acceptance Rate 0.99
# Stack the list of samples into a tensor and drop the first 500 warm-up draws.
samples_hmc = torch.stack(samples_hmc)[500:]
samples_hmc.shape
torch.Size([1000, 2])
# Empirical mean and covariance of the retained HMC samples.
hmc_mean = torch.mean(samples_hmc, dim=0)
hmc_cov = torch.tensor(np.cov(samples_hmc.T))
samples_hmc.shape, hmc_mean, hmc_cov
(torch.Size([1000, 2]),
tensor([0.0072, 0.0218]),
tensor([[0.2324, 0.0075],
[0.0075, 0.9991]], dtype=torch.float64))
plot_samples(samples_hmc, f'Hamiltorch HMC Samples, Acceptance rate = 1.00', lines = True)
import emcee
# Target for the emcee comparison.
# FIX: this previously used covariance 0.1*I, which is a *different*
# distribution from the log_likelihood target the from-scratch MH and HMC
# chains sample (stddev [0.5, 1]). For the three samplers' means and
# covariances to be comparable, emcee must target the same density:
# diag(0.5^2, 1^2) = diag(0.25, 1).
means = np.array([0, 0])
cov = np.array([[0.25, 0.0],
                [0.0, 1.0]])
def log_prob(x, mu, cov):
    """Unnormalised Gaussian log density: -0.5 * (x-mu)^T cov^{-1} (x-mu).

    The normalisation constant is omitted — it cancels in MCMC acceptance
    ratios, so only the quadratic form matters.
    """
    residual = x - mu
    return -0.5 * residual @ np.linalg.solve(cov, residual)
ndim = 2
nwalkers = 5
# Random initial positions for the 5 walkers.
p0 = np.random.rand(nwalkers, ndim)
# NOTE(review): EnsembleSampler's default move is the affine-invariant
# stretch move, not plain Metropolis-Hastings; for a strict MH comparison
# consider moves=emcee.moves.GaussianMove(...) — verify against the emcee docs.
sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob, args=[means, cov])
# Burn-in: 100 steps, then reset and draw 200 production steps
# (5 walkers * 200 steps = 1000 samples, matching the assignment).
state = sampler.run_mcmc(p0, 100)
sampler.reset()
result = sampler.run_mcmc(state, 200)
# Acceptance fraction averaged over walkers.
emcee_accprate = np.mean(sampler.acceptance_fraction)
print(
    "Mean acceptance fraction: {0:.3f}".format(
        emcee_accprate
    )
)
Mean acceptance fraction: 0.728
# Flatten the chains of all 5 walkers into one (1000, 2) sample array.
samples_emcee = sampler.get_chain(flat=True)
samples_emcee.shape
(1000, 2)
# Empirical mean and covariance of the emcee samples.
emcee_mean = np.array([np.mean(samples_emcee[:,0]),np.mean(samples_emcee[:,1])])
emcee_cov = np.cov(samples_emcee.T)
samples_emcee.shape, emcee_mean, emcee_cov
((1000, 2),
array([-6.58492212e-05, 1.04466032e-02]),
array([[0.08967083, 0.0057612 ],
[0.0057612 , 0.09591614]]))
plot_samples(torch.tensor(samples_emcee), f'Hamiltorch HMC Samples, Acceptance rate = {emcee_accprate:.3f}', lines = True)